"""
    Purpose: This module implements both deterministic and stochastic variance-reduction methods to solve nolinear inclusion:
                                                    0 in G(x) + T(x).
    where G from R^p to R^p, which is L-Lipschitz continuous and satisfies a weak-Minty solution condition.
    It consists of six differents algorithms.
      1. OG      - The optimistic gradedient method
      2. VR-FR   - The forward-reflected algorithm with SVRG variance reduction - A double loop variant.
      3. LVR-FR  - The forward-reflected algorithm with SVRG variance reduction - The loopless variant.
      4. VR-FRBS - The forward-reflected-backward splitting algorithm with SVRG variance reduction.
      5. VR-EG   - The extragradient algorithm with SVRG variance reduction.
      6. SAGA-FR - The forward-reflected method with SAGA variance reduction.

    This code is prepared for the submission:
       Stochastic Variance-Reduced Forward-Reflected Methods for Root-Finding Problems, ICLR 2025 Submission.
"""

import numpy as np
from numpy import linalg as la
import matplotlib.pyplot as plt
import random
import scipy as sci

"""
Purpose: Compute an SVRG estimator for the forward-reflected operator S.
    Inputs: + data, G_opr=the operator
            + w_cur=snapshot point, x_cur=x(k), x_prv=x(k-1)
            + gamma = the parameter of forward-reflected operator
            + mb_id = the indices of mini-batch, mb_size = size of mini-batch.
    Output: S_k = SVRG( G(w), Gb(w), Gb(x(k), Gb(x(k-1))
"""
def SVRG(data, G_opr, Gw_cur, w_cur, x_cur, x_prv, gamma, mb_id, mb_size):
    """Return the SVRG estimate of S(x_k) = G(x_k) - gamma*G(x_{k-1}).

    The operator is evaluated on the same mini-batch ``mb_id`` at the
    snapshot ``w_cur``, at ``x_cur`` and at ``x_prv`` (in that order), and
    the results are combined with the caller-supplied full value
    ``Gw_cur`` = G(w) as
        (1 - gamma)*(G(w) - G_b(w)) + G_b(x_k) - gamma*G_b(x_{k-1}).
    """
    def batch_eval(point):
        # mini-batch operator value on the shared index set
        return G_opr(data, point, mb_id, mb_size)

    snap_mb, cur_mb, prev_mb = batch_eval(w_cur), batch_eval(x_cur), batch_eval(x_prv)
    # control-variate correction anchored at the snapshot point
    correction = (1 - gamma)*(Gw_cur - snap_mb)
    return correction + cur_mb - gamma*prev_mb


"""
Method: An implementation of the optimistic gradient method for solving:
                            0 in G(x) + T(x).
    Inputs: + data = training data
            + G_op_eval = the function handle to evaluate G(x)
            + J_op_eval = evaluate the resolvant of eta*T
            + x0 = an initial point
            + kwargs = optional and control parameters
    Outputs: opt_sol = approximate solution, message = output message, epoch_hist = history.
"""
def OGA(data, G_op_eval, J_op_eval, x0, **kwargs):
    """Optimistic gradient (forward-reflected) method for 0 in G(x) + T(x).

    Deterministic iteration with full operator evaluations:
        S_k     = G(x_k) - gamma*G(x_{k-1})
        x_{k+1} = J_{eta T}(x_k - eta*S_k)

    Args:
        data: dict-like problem data. Reads "n" (size of the full index set)
            and, only when ``eta`` is not supplied, "L" (Lipschitz constant).
        G_op_eval: callable ``G_op_eval(data, x, idx, size)`` returning G at
            x on the index set ``idx`` (always the full set here).
        J_op_eval: callable ``J_op_eval(data, y, eta)`` returning the
            resolvent of eta*T at y.
        x0: initial point; copied before use.
        **kwargs: gamma (default 0.5), eta (default 2/L), n_max_iters
            (10000), tol (1e-8), is_term (True), verbose (None),
            print_step (20). Unrecognized keys are ignored.

    Returns:
        dict with "opt_sol" (current iterate), "message" (status string) and
        "epoch_hist" (list of dicts with "epoch", "error" = ||x_{k+1} - x_k||
        and "op_norm" = ||G_eta(x_k)|| relative to its initial value).
    """
    # parameters
    gamma       = kwargs.pop('gamma', 0.5)
    eta         = kwargs.pop('eta', None)
    if eta is None:
        # Read L lazily: using 2.0/L as the pop() default would be evaluated
        # (and fail) even when the caller supplies eta and data has no "L".
        eta = 2.0/data.get("L")
    n_max_iters = kwargs.pop('n_max_iters', 10000)
    tol         = kwargs.pop('tol', 1e-8)
    is_term     = kwargs.pop('is_term', True)

    # print setup
    verbose    = kwargs.pop('verbose', None)
    print_step = kwargs.pop('print_step', 20)

    # initialization
    n       = data.get("n")
    full_id = range(n)
    msg     = "Initialization"

    # initialize the iterate and the previous operator value G(x_{-1}) = G(x0)
    x_cur    = x0.copy()
    Gx_prev  = G_op_eval(data, x_cur, full_id, n)
    op_norm0 = la.norm((x_cur - J_op_eval(data, x_cur - eta*Gx_prev, eta))/eta)
    # normalizer for the reported relative norms; avoids 0/0 when x0 is a solution
    scale    = op_norm0 if op_norm0 > 0 else 1.0
    hist     = [{"epoch": 0, "error": 0, "op_norm": op_norm0/scale}]

    # print initial information
    if verbose:
        print('Solver: Optimistic Gradient Method for Generalized Equations ...')
        print(
            '{message:{fill}{align}{width}}'.format(message='', fill='=', align='^', width=42, ), '\n',
            '{message:{fill}{align}{width}}'.format(message='Epoch', fill=' ', align='^', width=7, ), '|',
            '{message:{fill}{align}{width}}'.format(message='Error', fill=' ', align='^', width=13, ), '|',
            '{message:{fill}{align}{width}}'.format(message='||G_eta(x)||', fill=' ', align='^', width=13, ), '\n',
            '{message:{fill}{align}{width}}'.format(message='', fill='-', align='^', width=42)
        )

    # main loop -- running up to n_max_iters iterations.
    for k in range(n_max_iters):

        # evaluate G(x_k) and form the forward-reflected direction.
        Gx_cur = G_op_eval(data, x_cur, full_id, n)
        Sx_cur = Gx_cur - gamma*Gx_prev

        # resolvent step.
        x_next = J_op_eval(data, x_cur - eta*Sx_cur, eta)

        # step length and norm of the gradient mapping at x_k.
        error   = la.norm(x_next - x_cur)
        op_norm = la.norm((x_cur - J_op_eval(data, x_cur - eta*Gx_cur, eta))/eta)

        if verbose and k % print_step == 0:
            print(
                '{:^8.0f}'.format(int(k)), '|',
                '{:^13.3e}'.format(error), '|',
                '{:^13.3e}'.format(op_norm/scale), '|'
            )
        hist.append({"epoch": k, "error": error, "op_norm": op_norm/scale})

        # termination: x_cur (where op_norm was measured) is returned.
        if is_term and op_norm <= tol*max(op_norm0, 1.0):
            msg = "Convergence acheived!"
            break

        # go to the next iteration.
        x_cur, Gx_prev = x_next, Gx_cur
    else:
        # Only reached when the budget ran out without meeting the tolerance
        # (including n_max_iters == 0), so a convergence detected on the very
        # last iteration is no longer overwritten.
        msg = "Exceed the maximum number of epochs. Increase it to run further ..."

    if verbose:
        print('{message:{fill}{align}{width}}'.format(message='', fill='=', align='^', width=42, ), '\n')

    return {"opt_sol": x_cur, "message": msg, "epoch_hist": hist}
    

"""
Method: An Implementation of Variance-Reduced Forward-Reflected Method for solving the following inclusion:
                                       0 in G(x) + T(x).
            This is a double-loop implememtation.
    Inputs: + data = training dataset
            + G_op_eval = the evaluation of operator G(x)
            + J_op_eval = evaluate the resolvant of eta*T
            + mb_size = mini-batch size
            + kwargs = optional and control parameters.
    Outputs: - opt_sol = approximate solution
             - message = solver messages
             - hist = history of training process.
"""
def Vr_FRA(data, G_op_eval, J_op_eval, x0, mb_size=5, **kwargs):
    """Double-loop SVRG forward-reflected method for 0 in G(x) + T(x).

    Each epoch computes the full value G(w) at the snapshot w, then runs
    ``n_inner_iters`` steps with the SVRG estimate S_k (see ``SVRG``):
        y_{k+1} = x_k - eta*S_k + ((2*gamma - 1)/gamma)*(y_k - x_k)
        x_{k+1} = J_{gamma*eta*T}(y_{k+1})
    The last inner iterate becomes the next snapshot.

    Args:
        data: dict-like; reads "n" and, only when ``eta`` is not supplied, "L".
        G_op_eval: ``G_op_eval(data, x, idx, size)`` -> G at x on indices idx.
        J_op_eval: ``J_op_eval(data, y, eta)`` -> resolvent of eta*T at y.
        x0: initial point; copied before use.
        mb_size: mini-batch size, sampled without replacement from range(n).
        **kwargs: gamma (0.75), eta (0.5/L), n_epochs (200),
            n_inner_iters (200), tol (1e-8), is_term (True), verbose (None),
            print_step (1). Unrecognized keys are ignored.

    Returns:
        dict with "opt_sol" (current snapshot w), "message" and "epoch_hist"
        (per-epoch "epoch", "error" = ||w_next - w|| and relative "op_norm").
    """
    # parameters
    gamma         = kwargs.pop('gamma', 0.75)
    eta           = kwargs.pop('eta', None)
    if eta is None:
        # default step needs L only when eta is absent
        eta = 0.5/data.get("L")
    n_epochs      = kwargs.pop('n_epochs', 200)
    n_inner_iters = kwargs.pop('n_inner_iters', 200)
    tol           = kwargs.pop('tol', 1e-8)
    is_term       = kwargs.pop('is_term', True)

    # print setup
    verbose    = kwargs.pop('verbose', None)
    print_step = kwargs.pop('print_step', 1)

    # initialization
    n       = data.get("n")
    full_id = range(n)
    msg     = "Initialization"

    # current snapshot and the snapshot of the previous epoch
    w_cur, w_prv = x0.copy(), x0.copy()

    # print initial information
    if verbose:
        print('Solver: SVRG-Forward-Reflected Method for Generalized Equations ...')
        print(
            '{message:{fill}{align}{width}}'.format(message='', fill='=', align='^', width=42, ), '\n',
            '{message:{fill}{align}{width}}'.format(message='Epoch', fill=' ', align='^', width=7, ), '|',
            '{message:{fill}{align}{width}}'.format(message='Error', fill=' ', align='^', width=13, ), '|',
            '{message:{fill}{align}{width}}'.format(message='||G_eta(x)||', fill=' ', align='^', width=13, ), '\n',
            '{message:{fill}{align}{width}}'.format(message='', fill='-', align='^', width=42)
        )

    # full operator value and gradient-mapping norm at the snapshot.
    Gw_cur   = G_op_eval(data, w_cur, full_id, n)
    op_norm  = la.norm((w_cur - J_op_eval(data, w_cur - eta*Gw_cur, eta))/eta)
    op_norm0 = op_norm
    # normalizer for the reported relative norms; avoids 0/0 when x0 is a solution
    scale    = op_norm0 if op_norm0 > 0 else 1.0
    hist     = [{"epoch": 0, "error": 0, "op_norm": op_norm/scale}]

    # main loop -- running up to n_epochs epochs.
    for s in range(n_epochs):

        # restart the inner loop from the snapshot; x_{-1} is the previous snapshot.
        x_cur, y_cur, x_prv = w_cur.copy(), w_cur.copy(), w_prv.copy()

        for _ in range(n_inner_iters):
            mb_id  = random.sample(full_id, mb_size)
            Sx_cur = SVRG(data, G_op_eval, Gw_cur, w_cur, x_cur, x_prv, gamma, mb_id, mb_size)

            x_prv  = x_cur
            y_cur  = x_cur - eta*Sx_cur + ((2.0*gamma-1.0)/gamma)*(y_cur - x_cur)
            x_cur  = J_op_eval(data, y_cur, gamma*eta)

        # the last inner iterate becomes the next snapshot.
        w_prv  = w_cur
        w_next = x_cur
        error  = la.norm(w_next - w_cur)

        if verbose and s % print_step == 0:
            print(
                '{:^8.0f}'.format(int(s)), '|',
                '{:^13.3e}'.format(error), '|',
                '{:^13.3e}'.format(op_norm/scale), '|'
            )
        hist.append({"epoch": s, "error": error, "op_norm": op_norm/scale})

        # termination is tested on the current snapshot, which is returned.
        if is_term and op_norm <= tol*max(op_norm0, 1.0):
            msg = "Convergence acheived!"
            break

        # move to the next epoch.
        w_cur   = w_next
        Gw_cur  = G_op_eval(data, w_cur, full_id, n)
        op_norm = la.norm((w_cur - J_op_eval(data, w_cur - eta*Gw_cur, eta))/eta)
    else:
        # budget exhausted (or n_epochs == 0) without meeting the tolerance.
        msg = "Exceed the maximum number of epochs. Increase it to run further ..."

    if verbose:
        print('{message:{fill}{align}{width}}'.format(message='', fill='=', align='^', width=42, ), '\n')

    return {"opt_sol": w_cur, "message": msg, "epoch_hist": hist}
    
    
"""
Method: An Implementation of Variance-Reduced Forward-Reflected Method for solving the following inclusion:
                                        0 in G(x) + T(x).
            This is a single-loop implememtation (known as loopless SVRG).
    Inputs: + data = training dataset
            + G_op_eval = the evaluation of operator G(x)
            + J_op_eval = evaluate the resolvant of eta*T
            + mb_size = mini-batch size
            + prob = probability to update snapshot point w.
            + kwargs = optional and control parameters.
    Outputs: - opt_sol = approximate solution
             - message = solver messages
             - hist = history of training process.
"""
def LVr_FRA(data, G_op_eval, J_op_eval, x0, mb_size=5, prob=0.05, **kwargs):
    """Loopless SVRG forward-reflected method for 0 in G(x) + T(x).

    Every iteration uses the SVRG estimate S_k (see ``SVRG``) built around
    the snapshot w:
        y_{k+1} = x_k - eta*S_k + ((2*gamma - 1)/gamma)*(y_k - x_k)
        x_{k+1} = J_{gamma*eta*T}(y_{k+1})
    and, with probability ``prob``, moves the snapshot to x_{k+1} and
    recomputes the full value G(w). One "epoch" is int(n/mb_size) iterations.

    Args:
        data: dict-like; reads "n" and, only when ``eta`` is not supplied, "L".
        G_op_eval: ``G_op_eval(data, x, idx, size)`` -> G at x on indices idx.
        J_op_eval: ``J_op_eval(data, y, eta)`` -> resolvent of eta*T at y.
        x0: initial point; copied before use.
        mb_size: mini-batch size, sampled without replacement from range(n).
        prob: probability of refreshing the snapshot at each iteration.
        **kwargs: gamma (0.75), eta (0.5/L), n_epochs (200), tol (1e-8),
            is_term (True), verbose (None). ``print_step`` is accepted but
            not used: progress is printed once per epoch.

    Returns:
        dict with "opt_sol", "message", "hist" (per iteration: "iter",
        "epoch", "error", absolute "op_norm") and "epoch_hist" (per epoch,
        relative gradient-mapping norm at the most recent snapshot).

    Note: the per-iteration "op_norm", which also drives termination, is
    computed with the stochastic estimate S_k rather than the full G(x_k).
    """
    # parameters
    gamma    = kwargs.pop('gamma', 0.75)
    eta      = kwargs.pop('eta', None)
    if eta is None:
        # default step needs L only when eta is absent
        eta = 0.5/data.get("L")
    n_epochs = kwargs.pop('n_epochs', 200)
    tol      = kwargs.pop('tol', 1e-8)
    is_term  = kwargs.pop('is_term', True)

    # print setup
    verbose = kwargs.pop('verbose', None)
    kwargs.pop('print_step', None)  # accepted for compatibility; unused

    # initialization
    n             = data.get("n")
    full_id       = range(n)
    n_count       = 0
    msg           = "Initialization"
    hist          = []
    n_inner_iters = int(n/mb_size)
    total_iters   = int(n_epochs*n_inner_iters)

    # initialize the iterate vectors (snapshot w, current x, auxiliary y, previous x)
    w_cur, x_cur, y_cur, x_prv = x0.copy(), x0.copy(), x0.copy(), x0.copy()

    # print initial information
    if verbose:
        print('Solver: Loopless-SVRG-Forward-Reflected Method for Generalized Equations ...')
        print(
            '{message:{fill}{align}{width}}'.format(message='', fill='=', align='^', width=42, ), '\n',
            '{message:{fill}{align}{width}}'.format(message='Epoch', fill=' ', align='^', width=7, ), '|',
            '{message:{fill}{align}{width}}'.format(message='Error', fill=' ', align='^', width=13, ), '|',
            '{message:{fill}{align}{width}}'.format(message='||G_eta(x)||', fill=' ', align='^', width=13, ), '\n',
            '{message:{fill}{align}{width}}'.format(message='', fill='-', align='^', width=42)
        )

    # full operator value at the first snapshot.
    Gw_cur     = G_op_eval(data, w_cur, full_id, n)
    op_norm    = la.norm((w_cur - J_op_eval(data, w_cur - eta*Gw_cur, eta))/eta)
    op_norm0   = op_norm
    op_norm1   = op_norm   # latest full gradient-mapping norm at the snapshot
    # normalizer for the reported relative norms; avoids 0/0 when x0 is a solution
    scale      = op_norm0 if op_norm0 > 0 else 1.0
    epoch_hist = [{"epoch": 0, "error": 0, "op_norm": op_norm/scale}]

    for k in range(total_iters):

        # sample a mini-batch and build the SVRG estimator.
        mb_id  = random.sample(full_id, mb_size)
        Sx_cur = SVRG(data, G_op_eval, Gw_cur, w_cur, x_cur, x_prv, gamma, mb_id, mb_size)
        x_prv  = x_cur

        # update the main iterates.
        y_cur  = x_cur - eta*Sx_cur + ((2.0*gamma - 1.0)/gamma)*(y_cur - x_cur)
        x_next = J_op_eval(data, y_cur, gamma*eta)

        # step length and (stochastic) gradient-mapping norm.
        error   = la.norm(x_next - x_cur)
        op_norm = la.norm((x_cur - J_op_eval(data, x_cur - eta*Sx_cur, eta))/eta)

        # with probability prob, move the snapshot and refresh G(w).
        if np.random.binomial(n=1, p=prob):
            w_cur    = x_next
            Gw_cur   = G_op_eval(data, w_cur, full_id, n)
            op_norm1 = la.norm((w_cur - J_op_eval(data, w_cur - eta*Gw_cur, eta))/eta)

        # record (and optionally print) once per epoch.
        if k % n_inner_iters == 0:
            n_count += 1
            epoch_hist.append({"epoch": n_count, "error": error, "op_norm": op_norm1/scale})
            if verbose:
                print(
                    '{:^8.0f}'.format(int(n_count)), '|',
                    '{:^13.3e}'.format(error), '|',
                    '{:^13.3e}'.format(op_norm/scale), '|'
                )
        hist.append({"iter": k, "epoch": n_count, "error": error, "op_norm": op_norm})

        if is_term and op_norm <= tol*max(op_norm0, 1.0):
            msg = "Convergence acheived!"
            break

        x_cur = x_next
    else:
        # budget exhausted (or empty) without meeting the tolerance.
        msg = "Exceed the maximum number of epochs. Increase it to run further ..."

    if verbose:
        print('{message:{fill}{align}{width}}'.format(message='', fill='=', align='^', width=42, ), '\n')

    return {"opt_sol": x_cur, "message": msg, "hist": hist, "epoch_hist": epoch_hist}
    

"""
Method: An Implementation of SAGA Variance-Reduced Forward-Reflected Method for solving the following inclusion:
                                    0 in G(x) + T(x).
            This is a single-loop implememtation of SAGA-FR method.
    Inputs: + data = training dataset
            + G_op_eval = the evaluation of operator G(x)
            + mb_size = mini-batch size
            + prob = probability to update snapshot point w.
            + kwargs = optional and control parameters.
    Outputs: - opt_sol = approximate solution
             - message = solver messages
             - hist = history of training process.
"""
def Saga_FRA(data, G_op_eval, Gb_op_eval, J_op_eval, x0, mb_size=5, **kwargs):
    """SAGA variance-reduced forward-reflected method for 0 in G(x) + T(x).

    A table keeps the most recently evaluated per-sample values G_i. Each
    iteration samples a mini-batch B and forms
        S_k = (1-gamma)*(mean_i table_i - mean_{i in B} table_i)
              + G_B(x_k) - gamma*G_B(x_{k-1}),
    overwrites the table entries of B with G_i(x_k), and updates
        y_{k+1} = x_k - eta*S_k + ((2*gamma - 1)/gamma)*(y_k - x_k)
        x_{k+1} = J_{gamma*eta*T}(y_{k+1}).

    Args:
        data: dict-like; reads "n", "p" and, only when ``eta`` is not
            supplied, "L".
        G_op_eval: ``G_op_eval(data, x, idx, size)`` -> G at x on indices idx.
        Gb_op_eval: ``Gb_op_eval(data, x, idx, size)`` -> (G at x on idx,
            ndarray of per-sample values). The full evaluation may return a
            (p, n) or an (n, p) table; the layout detected there is applied
            to every mini-batch result as well.
        J_op_eval: ``J_op_eval(data, y, eta)`` -> resolvent of eta*T at y.
        x0: initial point; copied before use.
        mb_size: mini-batch size, sampled without replacement from range(n).
        **kwargs: gamma (0.75), eta (0.5/L), n_epochs (200), tol (1e-8),
            is_term (True), verbose (None). ``print_step`` is accepted but
            not used: progress is printed once per epoch.

    Returns:
        dict with "opt_sol", "message", "hist" (per iteration) and
        "epoch_hist" (per epoch).

    Raises:
        ValueError: if the second output of ``Gb_op_eval`` is not an ndarray.

    NOTE(review): when n == p the two layouts cannot be told apart by shape;
    the table is then treated as (n, p) and transposed, as before.
    """
    # parameters
    gamma    = kwargs.pop('gamma', 0.75)
    eta      = kwargs.pop('eta', None)
    if eta is None:
        # default step needs L only when eta is absent
        eta = 0.5/data.get("L")
    n_epochs = kwargs.pop('n_epochs', 200)
    tol      = kwargs.pop('tol', 1e-8)
    is_term  = kwargs.pop('is_term', True)

    # print setup
    verbose = kwargs.pop('verbose', None)
    kwargs.pop('print_step', None)  # accepted for compatibility; unused

    # initialization
    n, p          = data.get("n"), data.get("p")
    full_id       = range(n)
    msg           = "Initialization"
    n_inner_iters = int(n/mb_size)
    total_iters   = int(n_epochs*n_inner_iters)
    n_count       = 0
    hist          = []

    # initialize the iterates
    x_prv, x_cur, y_cur = x0.copy(), x0.copy(), x0.copy()

    # print initial information
    if verbose:
        print('Solver: SAGA-Forward-Reflected Method for Generalized Equations ...')
        print(
            '{message:{fill}{align}{width}}'.format(message='', fill='=', align='^', width=42, ), '\n',
            '{message:{fill}{align}{width}}'.format(message='Iters', fill=' ', align='^', width=7, ), '|',
            '{message:{fill}{align}{width}}'.format(message='Error', fill=' ', align='^', width=13, ), '|',
            '{message:{fill}{align}{width}}'.format(message='||G_eta(x)||', fill=' ', align='^', width=13, ), '\n',
            '{message:{fill}{align}{width}}'.format(message='', fill='-', align='^', width=42)
        )

    # full evaluation: operator value and the initial per-sample table.
    Gx_cur, Gop_memory = Gb_op_eval(data, x_cur, full_id, n)
    op_norm    = la.norm((x_cur - J_op_eval(data, x_cur - eta*Gx_cur, eta))/eta)
    op_norm0   = op_norm
    # normalizer for the reported relative norms; avoids 0/0 when x0 is a solution
    scale      = op_norm0 if op_norm0 > 0 else 1.0
    epoch_hist = [{"epoch": 0, "error": 0, "op_norm": op_norm/scale}]

    # check the consistency of the per-sample table.
    if not isinstance(Gop_memory, np.ndarray):
        raise ValueError("The type of the second output of G(x) is not correct!")
    # The table is stored column-wise (column i holds G_i). Remember whether the
    # operator returns row-wise tables so mini-batch results get the same fix.
    rows_layout = Gop_memory.shape == (n, p)
    # np.array copies, so updating the table never writes into the caller's array.
    Gop_memory  = np.array(Gop_memory.T if rows_layout else Gop_memory)

    # average of the stored per-sample values.
    Gx_avg = np.mean(Gop_memory, axis=1)

    for k in range(total_iters):

        # sample a mini-batch
        mb_id = random.sample(full_id, mb_size)

        # mini-batch values at x(k-1) and x(k); the latter also per sample.
        Gxi_prv = G_op_eval(data, x_prv, mb_id, mb_size)
        Gxi_cur, Gxi_mb = Gb_op_eval(data, x_cur, mb_id, mb_size)
        if rows_layout:
            Gxi_mb = np.asarray(Gxi_mb).T

        # SAGA estimator of G(x(k)) - gamma*G(x(k-1)).
        Gzi_avg = np.mean(Gop_memory[:, mb_id], axis=1)
        Sx_cur  = (1-gamma)*(Gx_avg - Gzi_avg) + Gxi_cur - gamma*Gxi_prv

        # refresh the table entries of the mini-batch.
        Gop_memory[:, mb_id] = Gxi_mb
        Gx_avg = np.mean(Gop_memory, axis=1)
        x_prv  = x_cur

        # update the main iterates
        y_cur  = x_cur - eta*Sx_cur + ((2.0*gamma - 1.0)/gamma)*(y_cur - x_cur)
        x_next = J_op_eval(data, y_cur, gamma*eta)

        # step length and (stochastic) gradient-mapping norm.
        error   = la.norm(x_next - x_cur)
        op_norm = la.norm((x_cur - J_op_eval(data, x_cur - eta*Sx_cur, eta))/eta)

        # record (and optionally print) once per epoch.
        if k % n_inner_iters == 0:
            n_count += 1
            op_norm1 = la.norm((x_cur - J_op_eval(data, x_cur - eta*Gx_avg, eta))/eta)
            epoch_hist.append({"epoch": n_count, "error": error, "op_norm": op_norm1/scale})
            if verbose:
                print(
                    '{:^8.0f}'.format(int(n_count)), '|',
                    '{:^13.3e}'.format(error), '|',
                    '{:^13.3e}'.format(op_norm/scale), '|'
                )
        hist.append({"iter": k, "error": error, "op_norm": op_norm})

        if is_term and op_norm <= tol*max(op_norm0, 1.0):
            msg = "Convergence acheived!"
            break

        x_cur = x_next
    else:
        # budget exhausted (or empty) without meeting the tolerance.
        msg = "Exceed the maximum number of epochs. Increase it to run further ..."

    if verbose:
        print('{message:{fill}{align}{width}}'.format(message='', fill='=', align='^', width=42, ), '\n')
    return {"opt_sol": x_cur, "message": msg, "hist": hist, "epoch_hist": epoch_hist}


"""
Method: An Implementation of Variance-Reduced Forward-Reflected Method for solving the following inclusion: 
                                0 in G(x) + T(x).
            This is a single-loop implememtation (known as loopless SVRG).
    Inputs: + data = training dataset
            + G_op_eval = the evaluation of operator G(x)
            + J_op_eval = evaluate the resolvant of eta*T
            + mb_size = mini-batch size
            + prob = probability to update snapshot point w.
            + kwargs = optional and control parameters.
    Outputs: - opt_sol = approximate solution
             - message = solver messages
             - hist = history of training process.
"""
def Vr_FRBSA(data, G_op_eval, J_op_eval, x0, mb_size=5, prob=0.05, **kwargs):
    """Variance-reduced forward-reflected-backward splitting (loopless SVRG).

    Iteration of the method named in the solver banner (Alacaoglu et al. 2021):
        x_{k+1} = J_{eta T}(x_k - eta*[G(w_k) + G_B(x_k) - G_B(w_{k-1})])
        w_{k+1} = x_{k+1} with probability ``prob``, otherwise w_k,
    started from x_0 = w_0 = w_{-1} = x0. Here w_{k-1} is the snapshot used
    in the previous iteration, so it equals w_k unless w just moved.

    Args:
        data: dict-like; reads "n".
        G_op_eval: ``G_op_eval(data, x, idx, size)`` -> G at x on indices idx.
        J_op_eval: ``J_op_eval(data, y, eta)`` -> resolvent of eta*T at y.
        x0: initial point; copied before use.
        mb_size: mini-batch size, sampled without replacement from range(n).
        prob: probability of refreshing the snapshot at each iteration.
        **kwargs: eta (0.5), n_epochs (200), tol (1e-8), is_term (True),
            verbose (None). ``print_step`` is accepted but not used:
            progress is printed once per epoch.

    Returns:
        dict with "opt_sol", "message", "hist" (per iteration) and
        "epoch_hist" (per epoch, relative norm at the latest snapshot).
    """
    # parameters
    eta      = kwargs.pop('eta', 0.5)
    n_epochs = kwargs.pop('n_epochs', 200)
    tol      = kwargs.pop('tol', 1e-8)
    is_term  = kwargs.pop('is_term', True)

    # print setup
    verbose = kwargs.pop('verbose', None)
    kwargs.pop('print_step', None)  # accepted for compatibility; unused

    # initialization
    n             = data.get("n")
    full_id       = range(n)
    msg           = "Initialization"
    n_count       = 0
    n_inner_iters = int(n/mb_size)
    total_iters   = int(n_epochs*n_inner_iters)
    hist          = []

    # snapshot w_k, previous-iteration snapshot w_{k-1}, and the iterate x_k.
    w_cur, w_prv, x_cur = x0.copy(), x0.copy(), x0.copy()

    # print initial information
    if verbose:
        print('Solver: Variance-Reduced Forward-Reflected-Backward Splitting (Alacaoglu et al 2021) ...')
        print(
            '{message:{fill}{align}{width}}'.format(message='', fill='=', align='^', width=42, ), '\n',
            '{message:{fill}{align}{width}}'.format(message='Epoch', fill=' ', align='^', width=7, ), '|',
            '{message:{fill}{align}{width}}'.format(message='Error', fill=' ', align='^', width=13, ), '|',
            '{message:{fill}{align}{width}}'.format(message='||G_eta(x)||', fill=' ', align='^', width=13, ), '\n',
            '{message:{fill}{align}{width}}'.format(message='', fill='-', align='^', width=42)
        )

    # full operator value at the first snapshot.
    Gw_cur     = G_op_eval(data, w_cur, full_id, n)
    op_norm    = la.norm((w_cur - J_op_eval(data, w_cur - eta*Gw_cur, eta))/eta)
    op_norm0   = op_norm
    op_norm1   = op_norm   # latest full gradient-mapping norm at the snapshot
    # normalizer for the reported relative norms; avoids 0/0 when x0 is a solution
    scale      = op_norm0 if op_norm0 > 0 else 1.0
    epoch_hist = [{"epoch": 0, "error": 0, "op_norm": op_norm/scale}]

    for k in range(total_iters):

        # sample a mini-batch and build G(w_k) + G_B(x_k) - G_B(w_{k-1}).
        mb_id   = random.sample(full_id, mb_size)
        Gxi_cur = G_op_eval(data, x_cur, mb_id, mb_size)
        Gwi_prv = G_op_eval(data, w_prv, mb_id, mb_size)
        Sx_cur  = Gw_cur + Gxi_cur - Gwi_prv

        # resolvent step.
        x_next = J_op_eval(data, x_cur - eta*Sx_cur, eta)

        # step length and (stochastic) gradient-mapping norm.
        error   = la.norm(x_next - x_cur)
        op_norm = la.norm((x_cur - J_op_eval(data, x_cur - eta*Sx_cur, eta))/eta)

        # w_{k-1} <- w_k every iteration, whether or not the snapshot moves;
        # keeping the snapshot from before the last refresh would bias S_k.
        w_prv = w_cur
        if np.random.binomial(n=1, p=prob):
            w_cur    = x_next
            Gw_cur   = G_op_eval(data, w_cur, full_id, n)
            op_norm1 = la.norm((w_cur - J_op_eval(data, w_cur - eta*Gw_cur, eta))/eta)

        # record (and optionally print) once per epoch.
        if k % n_inner_iters == 0:
            n_count += 1
            epoch_hist.append({"epoch": n_count, "error": error, "op_norm": op_norm1/scale})
            if verbose:
                print(
                    '{:^8.0f}'.format(int(n_count)), '|',
                    '{:^13.3e}'.format(error), '|',
                    '{:^13.3e}'.format(op_norm/scale), '|'
                )
        hist.append({"epoch": n_count, "iter": k, "error": error, "op_norm": op_norm})

        if is_term and op_norm <= tol*max(op_norm0, 1.0):
            msg = "Convergence acheived!"
            break

        x_cur = x_next
    else:
        # budget exhausted (or empty) without meeting the tolerance.
        msg = "Exceed the maximum number of epochs. Increase it to run further ..."

    if verbose:
        print('{message:{fill}{align}{width}}'.format(message='', fill='=', align='^', width=42, ), '\n')

    return {"opt_sol": x_cur, "message": msg, "hist": hist, "epoch_hist": epoch_hist}
    

"""
Method: Variance-Reduced Extragradient Method for Solving Monotone Inclusions of the form:
                                0 in G(x) + T(x).
            This is based on the paper of ALACAOGLU and MALITSKY (COLT, 2022).
    Inputs: + data = training dataset
            + G_op_eval = the evaluation of operator G(x)
            + J_op_eval = evaluate the resolvant of eta*T
            + mb_size = mini-batch size
            + prob = probability to update snapshot point w.
            + kwargs = optional and control parameters.
    Outputs: - opt_sol = approximate solution
             - message = solver messages
             - hist = history of training process.
"""
def Vr_EGA(data, G_op_eval, J_op_eval, x0, mb_size=5, alpha=0.5, prob=0.05, **kwargs):
    """Loopless variance-reduced extragradient method for 0 in G(x) + T(x).

    Iteration:
        xbar_k  = (1-alpha)*w_k + alpha*x_k
        z_k     = J_{eta T}(xbar_k - eta*G(w_k))
        x_{k+1} = J_{eta T}(xbar_k - eta*[G(w_k) + G_B(z_k) - G_B(w_k)])
        w_{k+1} = x_{k+1} with probability ``prob``, otherwise w_k.

    Args:
        data: dict-like; reads "n".
        G_op_eval: ``G_op_eval(data, x, idx, size)`` -> G at x on indices idx.
        J_op_eval: ``J_op_eval(data, y, eta)`` -> resolvent of eta*T at y.
        x0: initial point; copied before use.
        mb_size: mini-batch size, sampled without replacement from range(n).
        alpha: weight of x_k in the anchor point xbar_k. The value passed
            here is now honoured (it used to be overwritten by 0.5).
        prob: probability of refreshing the snapshot at each iteration.
        **kwargs: eta (0.5), n_epochs (200), tol (1e-8), is_term (True),
            verbose (None). ``print_step`` is accepted but not used:
            progress is printed once per epoch.

    Returns:
        dict with "opt_sol", "message", "hist" (per iteration) and
        "epoch_hist" (per epoch, relative norm at the latest snapshot).
    """
    # parameters (alpha is a named parameter, so it can never appear in kwargs)
    eta      = kwargs.pop('eta', 0.5)
    n_epochs = kwargs.pop('n_epochs', 200)
    tol      = kwargs.pop('tol', 1e-8)
    is_term  = kwargs.pop('is_term', True)

    # print setup
    verbose = kwargs.pop('verbose', None)
    kwargs.pop('print_step', None)  # accepted for compatibility; unused

    # initialization
    n             = data.get("n")
    full_id       = range(n)
    msg           = "Initialization"
    n_inner_iters = int(n/mb_size)
    total_iters   = int(n_epochs*n_inner_iters)
    n_count       = 0
    hist          = []

    # snapshot w and iterate x.
    w_cur, x_cur = x0.copy(), x0.copy()

    # print initial information
    if verbose:
        print('Solver: Variance-Reduced Extragradient Method (Alacaoglu & Malitsky 2022) ...')
        print(
            '{message:{fill}{align}{width}}'.format(message='', fill='=', align='^', width=42, ), '\n',
            '{message:{fill}{align}{width}}'.format(message='Epoch', fill=' ', align='^', width=7, ), '|',
            '{message:{fill}{align}{width}}'.format(message='Error', fill=' ', align='^', width=13, ), '|',
            '{message:{fill}{align}{width}}'.format(message='||G_eta(x)||', fill=' ', align='^', width=13, ), '\n',
            '{message:{fill}{align}{width}}'.format(message='', fill='-', align='^', width=42)
        )

    # full operator value at the first snapshot.
    Gw_cur     = G_op_eval(data, w_cur, full_id, n)
    op_norm    = la.norm((w_cur - J_op_eval(data, w_cur - eta*Gw_cur, eta))/eta)
    op_norm0   = op_norm
    op_norm1   = op_norm   # latest full gradient-mapping norm at the snapshot
    # normalizer for the reported relative norms; avoids 0/0 when x0 is a solution
    scale      = op_norm0 if op_norm0 > 0 else 1.0
    epoch_hist = [{"epoch": 0, "error": 0, "op_norm": op_norm/scale}]

    for k in range(total_iters):

        # anchor point and extrapolation step using the full G(w).
        xbar_cur = (1.0-alpha)*w_cur + alpha*x_cur
        z_cur    = J_op_eval(data, xbar_cur - eta*Gw_cur, eta)

        # sample a mini-batch and build G(w) + G_B(z) - G_B(w).
        mb_id   = random.sample(full_id, mb_size)
        Gzi_cur = G_op_eval(data, z_cur, mb_id, mb_size)
        Gwi_cur = G_op_eval(data, w_cur, mb_id, mb_size)
        Sx_cur  = Gw_cur + Gzi_cur - Gwi_cur

        # correction step from the anchor point.
        x_next = J_op_eval(data, xbar_cur - eta*Sx_cur, eta)

        # step length and (stochastic) gradient-mapping norm.
        error   = la.norm(x_next - x_cur)
        op_norm = la.norm((x_cur - J_op_eval(data, x_cur - eta*Sx_cur, eta))/eta)

        # with probability prob, move the snapshot and refresh G(w).
        if np.random.binomial(n=1, p=prob):
            w_cur    = x_next
            Gw_cur   = G_op_eval(data, w_cur, full_id, n)
            op_norm1 = la.norm((w_cur - J_op_eval(data, w_cur - eta*Gw_cur, eta))/eta)

        # record (and optionally print) once per epoch.
        if k % n_inner_iters == 0:
            n_count += 1
            epoch_hist.append({"epoch": n_count, "error": error, "op_norm": op_norm1/scale})
            if verbose:
                print(
                    '{:^8.0f}'.format(int(n_count)), '|',
                    '{:^13.3e}'.format(error), '|',
                    '{:^13.3e}'.format(op_norm/scale), '|'
                )
        hist.append({"epoch": n_count, "iter": k, "error": error, "op_norm": op_norm})

        if is_term and op_norm <= tol*max(op_norm0, 1.0):
            msg = "Convergence acheived!"
            break

        x_cur = x_next
    else:
        # budget exhausted (or empty) without meeting the tolerance.
        msg = "Exceed the maximum number of epochs. Increase it to run further ..."

    if verbose:
        print('{message:{fill}{align}{width}}'.format(message='', fill='=', align='^', width=42, ), '\n')

    return {"opt_sol": x_cur, "message": msg, "hist": hist, "epoch_hist": epoch_hist}
    
