
import numpy as np
from .projections import euclidean_proj_simplex

def compute_strong_gap(operator, x):
    """Return the gap between the best coordinate of ``operator(x)`` and its average under ``x``.

    Evaluates ``values = operator(x)`` and returns
    ``max_i values[i] - <x, values>``. When ``x`` has non-negative entries
    summing to one, ``<x, values>`` is a convex combination of the entries of
    ``values``, so the result is non-negative.

    Args:
        operator: Callable applied to ``x``; its result is used with
            ``np.max`` and ``np.dot`` against ``x``.
        x: The point at which the gap is measured.

    Returns:
        ``np.max(operator(x)) - np.dot(x, operator(x))``, with ``operator``
        evaluated only once.
    """
    values = operator(x)
    best_value = np.max(values)
    weighted_value = np.dot(x, values)
    return best_value - weighted_value

def solve_vi(K, operator, iterations, eta, return_path=False):
    """Run projected fixed-point iterations starting from the uniform point.

    Starting from ``x = np.ones(K) / K``, repeats
    ``x <- euclidean_proj_simplex(x + eta * operator(x))`` for ``iterations``
    steps. ``euclidean_proj_simplex`` is imported from ``.projections``. It is
    presumably the Euclidean projection onto the probability simplex, which
    you should confirm against that module.

    Args:
        K: Length of the iterate vector.
        operator: Callable evaluated at each iterate. Its output is added
            to the iterate (scaled by ``eta``), so it is assumed to return
            a vector of length ``K``.
        iterations: Number of update steps to perform.
        eta: Step size multiplying ``operator(x)`` in each update.
        return_path: If True, also return per-step diagnostics.

    Returns:
        The final iterate ``x`` by default. If ``return_path`` is True,
        returns the tuple ``(x, epsilons, return_vals, x_path)``:

        - ``epsilons[i]`` is the L1 distance between ``x_path[i]`` and
          ``x_path[i + 1]`` (length ``iterations``).
        - ``return_vals[i]`` is ``operator(x_path[i])`` (length
          ``iterations + 1``).
        - ``x_path`` holds every iterate, including the starting point
          (length ``iterations + 1``).
    """
    x = np.ones(K) / K
    # Evaluate the operator once per iterate. The same value is recorded in
    # return_vals and drives the next update, instead of being recomputed.
    v = operator(x)
    epsilons = []
    return_vals = [v]
    x_path = [x]

    for _ in range(iterations):
        x_next = euclidean_proj_simplex(x + eta * v)
        # L1 size of the step; serves as a convergence diagnostic.
        epsilons.append(np.abs(x - x_next).sum())
        x = x_next
        v = operator(x)
        return_vals.append(v)
        x_path.append(x)

    if return_path:
        return x, epsilons, return_vals, x_path
    return x