

def onlineSolver_poissonRg(X_init, y_init, X_del, y_del, X_add, y_add, b0, lam1, lam2):
    """Update a sparse coefficient vector by integrating an ODE in a homotopy parameter.

    The parameter ``t`` runs over the module-level interval
    ``[underline_, bar_]``.  Along the path the ``X_add`` block is weighted by
    ``t`` and the ``X_del`` block by ``1 - t`` while ``X_init`` keeps weight 1.
    A single integration over the whole interval is tried first; if its
    acceptance test fails, the interval is split into 50 steps and the active
    set is revised after every step.

    Parameters
    ----------
    X_init, y_init : data block that keeps weight 1 along the whole path.
    X_del, y_del : data block weighted by ``1 - t``.
    X_add, y_add : data block weighted by ``t``.
    b0 : numpy array of starting coefficients; entries with ``|b| <= 1e-7``
        are treated as inactive.
    lam1, lam2 : forwarded to ``nabla1_ell`` / ``nabla2_ell``; ``lam1`` also
        divides the inactive-coordinate statistic, so it must be non-zero
        whenever the stepwise branch runs.

    Returns
    -------
    numpy array of ``len(b0)`` coefficients.

    Notes
    -----
    ``nabla1_ell``, ``nabla2_ell``, ``underline_`` and ``bar_`` are defined
    elsewhere in the module.  Presumably the first two are the gradient and
    Hessian of the penalised loss -- TODO confirm against their definitions.
    """
    n_features = len(b0)
    active = np.abs(b0) > 1e-7
    bt_A = b0[active]
    X_init_A, X_del_A, X_add_A = X_init[:, active], X_del[:, active], X_add[:, active]
    init_tol = nabla1_ell(
        np.concatenate([X_init_A, X_del_A], axis=0),
        np.concatenate([y_init, y_del], axis=0),
        bt_A, lam1, lam2,
    )

    def func(t, bt_A):
        # X_init_A / X_del_A / X_add_A are read from the enclosing scope on
        # every call, so the stepwise loop below can swap in a new active set
        # simply by rebinding them.
        key_mat = (
            nabla2_ell(X_init_A, y_init, bt_A, lam2)
            + t * nabla2_ell(X_add_A, y_add, bt_A, lam2)
            + (1 - t) * nabla2_ell(X_del_A, y_del, bt_A, lam2)
        )
        r_vec = nabla1_ell(X_del_A, y_del, bt_A, lam1, lam2) - nabla1_ell(X_add_A, y_add, bt_A, lam1, lam2)
        # Solve the linear system directly rather than forming the inverse.
        return np.linalg.solve(key_mat, r_vec).flatten()

    # --- Attempt 1: one integration over the whole interval -----------------
    lne_search = odeint(func, bt_A, [underline_, bar_], tfirst=True)
    bt = np.zeros([2, n_features])
    bt[:, active] = lne_search
    # NOTE(review): this is evaluated at the *initial* active coefficients
    # (bt_A), not at the integrated solution lne_search[-1] -- confirm that
    # is intended before changing it, since it decides which branch runs.
    search_tol = nabla1_ell(
        np.concatenate([X_init_A, X_add_A], axis=0),
        np.concatenate([y_init, y_add], axis=0),
        bt_A, lam1, lam2,
    )

    no_sign_change = np.sum((bt[0] * bt[1]) >= 0) == n_features
    if ((np.linalg.norm(init_tol) + 1e-6) > np.linalg.norm(search_tol)) and no_sign_change:
        return bt[-1]

    # --- Attempt 2: 50 short integrations with active-set revision ----------
    bt = b0.copy()
    bt_p = bt.copy()
    Ns = 50
    ds = (bar_ - underline_) / Ns
    for s in np.linspace(underline_, bar_, Ns + 1)[:-1]:
        bt_A = bt[active]
        X_init_A, X_del_A, X_add_A = X_init[:, active], X_del[:, active], X_add[:, active]

        bt_A = odeint(func, bt_A, [s, s + ds], tfirst=True)[-1]
        bt = np.zeros(n_features)
        bt[active] = bt_A

        # Event: a coefficient changed sign during the step.  Clamp it to
        # zero, shrink the active set and move on to the next step.
        flipped = (bt * bt_p) < 0
        if np.sum(flipped) != 0:
            bt[flipped] = 0
            active = np.abs(bt) > 1e-7
            bt_p = bt.copy()
            continue

        # Event: an inactive coordinate whose scaled statistic sits within
        # 1e-7 of magnitude 1 is re-activated with a tiny seed value.
        inactive = np.abs(bt) <= 1e-7
        X_init_inA, X_del_inA, X_add_inA = X_init[:, inactive], X_del[:, inactive], X_add[:, inactive]
        bt_inA = bt[inactive]
        subdiff = (
            X_init_inA.T@(np.exp(X_init_inA@bt_inA) - y_init)
            + s * X_add_inA.T@(np.exp(X_add_inA@bt_inA) - y_add)
            + (1 - s) * X_del_inA.T@(np.exp(X_del_inA@bt_inA) - y_del)
        )
        omegas = (subdiff + 2 * lam2 * bt_inA) / (- lam1)
        event1_arg = (np.abs(omegas) > 1. - 1e-7) * (np.abs(omegas) < 1. + 1e-7)

        if np.sum(event1_arg) != 0:
            new_val = (
                X_init_inA[:, event1_arg].T@(1. - y_init)
                + s * X_add_inA[:, event1_arg].T@(1. - y_add)
                + (1 - s) * X_del_inA[:, event1_arg].T@(1. - y_del)
            )
            bt_inA[event1_arg] = new_val * 1e-7
            bt[inactive] = bt_inA
            active = np.abs(bt) > 1e-7

        bt_p = bt.copy()

    return bt




def _group_ls_register(grp, columns, blocks):
    """Append group ``grp``'s feature columns to ``columns`` and record, in
    ``blocks``, the positions those columns occupy in the packed active vector."""
    members = grp[1:]
    start = len(columns)
    blocks.append(np.arange(start, start + len(members)))
    columns.extend(members)


def _group_ls_rhs(X_init_A, X_chg_A, y_chg, blocks, lam):
    """Build the ODE right-hand side ``f(theta, w_A)`` for a fixed active set."""
    # These Gram matrices do not depend on theta or w_A, so compute them once.
    gram_init = X_init_A.T @ X_init_A
    gram_chg = X_chg_A.T @ X_chg_A

    def rhs(theta, w_A):
        w_A = w_A.reshape(-1, 1)
        penalty_blocks = []
        for idx in blocks:
            w_g = w_A[idx].reshape(-1, 1)
            norm_g = np.linalg.norm(w_g)
            # sqrt(group size) * (|w|^2 I - w w^T) / |w|^3
            penalty_blocks.append(
                np.sqrt(len(idx)) * (norm_g**2 * np.eye(len(idx)) - w_g @ w_g.T) / (norm_g**3)
            )
        d2R = block_diag(*penalty_blocks)
        lhs = gram_init + 2 * (sub) * lam * d2R + theta * gram_chg
        # Solve the linear system directly rather than forming the inverse.
        ans = -np.linalg.solve(lhs, X_chg_A.T @ (X_chg_A @ w_A - y_chg))
        return ans.flatten()

    return rhs


def _group_ls_homotopy(X_init, y_init, X_chg, y_chg, b0, lam, s_start, s_end):
    """Move the weight of the (X_chg, y_chg) block from ``s_start`` to ``s_end``.

    ``y_chg`` must already be a column vector.  Returns the coefficient vector
    at ``s_end``.  The activity flags stored in the module-level ``group`` are
    updated in place whenever the stepwise branch detects an event.
    """
    n_features = len(b0)

    def criterion(grp, beta, s):
        # Norm of the group's residual correlation (changing block weighted
        # by s), scaled by 1/(2*sub), minus lam*sqrt(group size).  A sign
        # change of this quantity is what triggers activation of a group.
        cols = grp[1:]
        corr = (
            X_init[:, cols].T @ (y_init - X_init @ beta).reshape(-1, 1)
            + s * X_chg[:, cols].T @ (y_chg - X_chg @ beta.reshape(-1, 1))
        )
        return np.linalg.norm(corr) / 2 / (sub) - lam * np.sqrt(len(cols))

    active, blocks = [], []
    for grp in group:
        if grp[0] == 1:
            _group_ls_register(grp, active, blocks)

    # --- Attempt 1: one integration over the whole interval -----------------
    rhs = _group_ls_rhs(X_init[:, active], X_chg[:, active], y_chg, blocks, lam)
    path = odeint(rhs, b0[active], [s_start, s_end], tfirst=True)
    ends = np.zeros([2, n_features])
    ends[:, active] = path

    event = False
    for grp in group:
        if grp[0] == 1:
            # Only the first coefficient of the group is tested for a sign flip.
            if ends[0][grp[1]] * ends[1][grp[1]] < 0:
                event = True
        if grp[0] == 0:
            if criterion(grp, ends[0], s_start) * criterion(grp, ends[1], s_end) < 0:
                event = True
        if event:
            break
    if not event:
        return ends[-1]

    # --- Attempt 2: 100 short integrations with active-set revision ---------
    Ns = 100
    ds = 1. / Ns
    step = ds if s_end > s_start else -ds
    bt = b0.copy()
    bt_p = b0
    for s in np.linspace(s_start, s_end, Ns + 1)[:-1]:
        rhs = _group_ls_rhs(X_init[:, active], X_chg[:, active], y_chg, blocks, lam)
        bt_A = odeint(rhs, bt[active], [s, s + step], tfirst=True)[-1]
        bt = np.zeros(n_features)
        bt[active] = bt_A

        # Rebuild the active set.  Groups are visited in order and `bt` is
        # modified in place, so later groups see earlier seeds.
        active, blocks = [], []
        for grp in group:
            if grp[0] == 1:
                if bt_p[grp[1]] * bt[grp[1]] < 0:
                    # Sign flip of the group's first coefficient: deactivate.
                    # NOTE(review): the group's coefficients stay in `bt`
                    # until the next step rebuilds it, so after the final
                    # step they are returned non-zero -- confirm intended.
                    grp[0] = 0
                else:
                    _group_ls_register(grp, active, blocks)

            # Deliberately `if`, not `elif`: a group deactivated just above
            # is immediately tested for re-activation.
            if grp[0] == 0:
                if criterion(grp, bt_p, s) * criterion(grp, bt, s + step) < 0:
                    grp[0] = 1
                    cols = grp[1:]
                    seed = (
                        X_init[:, cols].T @ (y_init - X_init @ bt).reshape(-1, 1)
                        + s * (X_chg[:, cols].T @ (y_chg - X_chg @ bt.reshape(-1, 1)))
                    )
                    # Start the newly active group from a tiny value so its
                    # norm in the penalty term is non-zero.
                    bt[cols] = seed.flatten() * eps
                    _group_ls_register(grp, active, blocks)
        bt_p = bt.copy()

    return bt


def onlineSolver_groupLs(X_init, y_init, X_del, y_del, X_add, y_add, b0, lam):
    """Update group-penalised least-squares coefficients after removing one
    data block and adding another.

    Two passes are run in sequence.  The first starts from ``b0`` and moves
    the weight of ``(X_del, y_del)`` from 1 to 0; the second starts from the
    first pass's result and moves the weight of ``(X_add, y_add)`` from 0 to 1.
    Each pass integrates an ODE in that weight, first in a single shot and,
    if a group sign flip or activation is detected, in 100 steps with the
    active set rebuilt after every step.

    Parameters
    ----------
    X_init, y_init : data block present in both passes.  ``y_init`` is
        combined with ``X_init @ beta`` for a 1-D ``beta``, so it is
        presumably 1-D -- TODO confirm.
    X_del, y_del : block being removed; ``y_del`` is reshaped to a column.
    X_add, y_add : block being added; ``y_add`` is reshaped to a column.
    b0 : numpy array of starting coefficients.
    lam : penalty weight.

    Returns
    -------
    numpy array of ``len(b0)`` coefficients.

    Notes
    -----
    Reads the module-level names ``group``, ``sub`` and ``eps``.  Each entry
    of ``group`` is used as ``[flag, col_1, col_2, ...]`` with flag 1 for an
    active group and 0 for an inactive one.  The flags are modified in place
    by the stepwise branch, so ``group`` carries state between calls.
    ``sub`` scales the penalty as ``2 * sub * lam``; presumably a sample
    count -- TODO confirm.
    """
    b_removed = _group_ls_homotopy(X_init, y_init, X_del, y_del.reshape(-1, 1), b0, lam, 1., 0.)
    return _group_ls_homotopy(X_init, y_init, X_add, y_add.reshape(-1, 1), b_removed, lam, 0., 1.)