from acs.auxiliaries import *

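"""
This file contains optimization methods (proximal iterative hard thresholding and greedy
selection) for the sparse, non-negative problem described in each function's docstring.
"""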

def proximal_iht(Phi, y, variance, k, alpha, beta, max_iter=800, tol=1e-5, verbose=True, reg_type='one'):
    """
    optimization objective:
        min_w   f1(w) + alpha * f2(w) + beta * f3(w) s.t.  ||w||_0 <= k, w >= 0
        where   f1(w) = ||Phi * w - y||^2  is approximation loss
                f2(w) = - <variance, 1(w>0)>   is variance loss
                if reg_type is 'one',  f3(w) = ||w - 1||^2   to keep each w_i close to 1
                if reg_type is 'mean', f3(w) = ||w - mean(w)||^2  to keep the w_i from drifting too far apart
                alpha and beta are parameters to balance the three objectives
                Note: [f1(w) + alpha * f2(w)] is the real objective, since f3 is only a regularizer
    procedure:
        iterative hard thresholding with the proximal operator. Each iteration performs
          1. a proximal gradient step on f1 + beta * f3 from the extrapolated point v,
          2. a debias gradient step restricted to the selected support,
          3. a momentum (extrapolation) update of v.
        Stops after max_iter iterations or when ||w - w_prev|| / (||w|| + 0.5) < tol.
    :param: Phi: torch tensor of shape (M, N)
    :param: y: torch tensor of shape (M, 1)
    :param: variance: torch tensor of shape (N,)
    :param: k: maximum number of non-zero entries in w
    :param: alpha: positive float value
    :param: beta: positive float value
    :param: max_iter: maximum number of iterations (0 returns the all-zero w and an empty support)
    :param: tol: tolerance of the relative-change stopping criterion
    :param: verbose: print the objective roughly every max_iter / 10 iterations
    :param: reg_type: 'one' or 'mean', selects f3 (forwarded to the regularizer helpers)
    :return: (w, supp): w is the (N, 1) solution; supp is the support as returned by
             projection_to_sparse_nonnegative (presumably a list of indices -- TODO confirm)
    """
    def prox(w, k, alpha, variance):
        """
        prox(w) = argmin_x  -alpha * <variance, 1(x>0)> + 0.5 * ||x-w||_2^2     s.t. ||x||_0 <= k and x >= 0

        Keeping entry i at x_i = w_i > 0 instead of x_i = 0 lowers the objective by
        0.5 * w_i^2 + alpha * variance_i, so the k entries with the largest such score are kept.
        Note: w is modified in place and also returned.
        :return: (w, supp) where supp is the list of indices of the non-zero entries of w
        """
        flat = w.reshape([-1])
        score = 0.5 * flat.pow(2) + alpha * variance
        # x >= 0 maps negative entries to 0, so they must never occupy one of the k kept slots
        # (a finite penalty such as -1e10 * w_i^2 fails for tiny negative w_i).
        score[flat < 0] = -float('inf')
        # k >= N means nothing has to be dropped; a negative slice bound would drop the wrong entries
        n_drop = max(w.shape[0] - k, 0)
        supp_complement = score.argsort()[:n_drop]
        w[supp_complement] = 0.
        w[w < 0] = 0
        supp = torch.nonzero(w.reshape([-1])).reshape([-1]).tolist()
        return w, supp

    N = Phi.shape[1]
    device = Phi.device
    dtype = Phi.dtype
    # at least 1: for max_iter < 10 an interval of 0 would make `iteration % print_interval` raise
    print_interval = max(1, max_iter // 10)
    w = torch.zeros([N, 1], dtype=dtype, device=device)
    v = torch.zeros([N, 1], dtype=dtype, device=device)
    supp = []  # returned unchanged when the loop does not run (max_iter <= 0)
    for iteration in range(max_iter):
        # prox() below receives a freshly computed tensor, so this reference is not mutated later on
        w_prev = w

        # proximal gradient step
        grad_f1 = square_loss_gradient(Phi, y, v)
        grad_reg = regularizer_gradient(v, supp='full', beta=beta, reg_type=reg_type)
        grad_f1 = grad_f1 + grad_reg
        step_size = step_line_search_reg(Phi, y, v, grad_f1, k, beta, reg_type=reg_type)
        w, supp = prox(v - step_size * grad_f1, k, alpha, variance)

        # debias step: gradient step restricted to the current support, then re-project
        grad_f = square_loss_gradient(Phi, y, w)
        grad_reg = regularizer_gradient(w, supp, beta, reg_type=reg_type)
        grad_f_supp = grad_f[supp] + grad_reg
        step_size = step_line_search_reg(Phi[:, supp], y, w[supp], grad_f_supp, k, beta, reg_type=reg_type)
        w[supp] -= step_size * grad_f_supp
        w, supp = projection_to_sparse_nonnegative(w, k, already_k_sparse=True)

        # momentum: extrapolate along the last update direction
        momentum = step_line_search(Phi, y, w, w - w_prev)
        v = w - momentum * (w - w_prev)

        # check convergence and report
        if torch.norm(w - w_prev) / (torch.norm(w) + 0.5) < tol:
            break
        if verbose and iteration % print_interval == 0:
            f, f1, f2 = obj(y, Phi, variance, alpha, w, supp)
            print('at iteration {}:'.format(iteration))
            print('training objective (f1 + alpha * f2) is {}, approximation loss (f1) is {}, selected variance loss '
                  '(alpha * f2) is {}, selected original variance (f2) is {}'.format(f, f1, f2, f2 / (alpha+1e-30)))

    return w, supp


def greedy(Phi, y, variance, k, alpha, sigma, L, beta, max_vertex_selection=2, verbose=True, reg_type='one'):
    """
    optimization objective:
        min_w   f1(w) + alpha * f2(w) + beta * f3(w) s.t.  ||w||_0 <= k, w >= 0
        where   f1(w) = ||Phi * w - y||^2  is approximation loss
                f2(w) = - <variance, 1(w>0)>   is variance loss
                if reg_type is 'one',  f3(w) = ||w - 1||^2   to keep each w_i close to 1
                if reg_type is 'mean', f3(w) = ||w - mean(w)||^2  to keep the w_i from drifting too far apart
                alpha and beta are parameters to balance the three objectives
                Note: [f1(w) + alpha * f2(w)] is the real objective, since f3 is only a regularizer
    procedure:
        greedy update on the relaxation problem
            min_w   f1(w) + alpha * f2(w) + beta * f3(w)  s.t.  w >= 0   and   <w, sigma> <= L
        each vertex can be selected at most max_vertex_selection times.
        A debias step is conducted after each greedy update. The loop stops once the support
        reaches k entries, after max_vertex_selection * k iterations, or when every vertex has
        been used up.
    :param: Phi: torch tensor of shape (M, N)
    :param: y: torch tensor of shape (M, 1)
    :param: variance: torch tensor of shape (N,)
    :param: k: maximum number of non-zero entries in w
    :param: alpha: positive float value
    :param: sigma: positive torch tensor of shape (N,); currently unused (the budget <w, sigma> <= L
                   is not enforced by this implementation)
    :param: L: positive float value; currently unused, see sigma
    :param: beta: positive float value
    :param: max_vertex_selection: an integer number
    :param: verbose: print the objective roughly every max_iter / 10 iterations
    :param: reg_type: 'one' or 'mean', selects f3; anything else raises ValueError
    :return: (w, supp): w is the (N, 1) solution; supp is the support as returned by
             projection_to_sparse_nonnegative (presumably a list of indices -- TODO confirm)
    """
    def select_vertex(grad, alpha_variance, vertex_count, unavailable_vertices, w):
        """
        Return the available coordinate with the smallest score, increment its selection count,
        and mark it unavailable once it has been selected max_vertex_selection times.
        Score: grad_i - alpha * variance_i for 'mean', grad_i * (1 - w_i) - alpha * variance_i for 'one'.
        """
        if reg_type == 'mean':
            u = grad.reshape([-1]) - alpha_variance  # default step size 1
        elif reg_type == 'one':
            step_size = 1 - w.reshape([-1])
            u = grad.reshape([-1]) * step_size - alpha_variance
        else:
            raise ValueError("unknown reg_type {!r}, expected 'one' or 'mean'".format(reg_type))
        u[unavailable_vertices] = np.inf
        vertex = u.argmin().item()
        vertex_count[vertex] += 1
        if vertex_count[vertex] >= max_vertex_selection:
            # this vertex has been used too many times
            unavailable_vertices.append(vertex)
        return vertex

    N = Phi.shape[1]
    device = Phi.device
    dtype = Phi.dtype
    max_iter = max_vertex_selection * k
    # at least 1: for max_iter < 10 an interval of 0 would make `iteration % print_interval` raise
    print_interval = max(1, max_iter // 10)
    vertex_count = [0] * N
    unavailable_vertices = []
    w = torch.zeros([N, 1], dtype=dtype, device=device)
    alpha_variance = alpha * variance
    supp = []  # returned unchanged when the loop does not run (max_iter <= 0)

    for iteration in range(max_iter):
        if len(unavailable_vertices) >= N:
            # every vertex reached max_vertex_selection; argmin over an all-inf score would
            # otherwise silently reselect vertex 0 beyond its cap (only reachable when k > N)
            break

        # greedy update step
        grad_f1 = square_loss_gradient(Phi, y, w)
        grad_reg = regularizer_gradient(w, 'full', beta, reg_type=reg_type)
        grad_f = grad_f1 + grad_reg
        vertex = select_vertex(grad_f, alpha_variance, vertex_count, unavailable_vertices, w)

        supp = torch.nonzero(w.reshape([-1])).reshape([-1]).tolist()
        if vertex not in supp:
            supp.append(vertex)
            vertex_position = len(supp) - 1
        else:
            vertex_position = supp.index(vertex)
        # unit direction on the selected vertex, expressed in support coordinates
        v = torch.zeros([len(supp), 1], dtype=dtype, device=device)
        v[vertex_position] = 1
        # NOTE(review): assumes step_line_search_reg returns a signed step so that
        # `w -= step * v` moves the selected entry in the descent direction -- confirm
        step_size = step_line_search_reg(Phi[:, supp], y, w[supp], v, k, beta, reg_type=reg_type)
        w[supp] -= step_size * v

        # debias step: gradient step restricted to the current support, then re-project
        grad_f = square_loss_gradient(Phi, y, w)
        grad_reg = regularizer_gradient(w, supp, beta, reg_type=reg_type)
        grad_f_supp = grad_f[supp] + grad_reg
        step_size = step_line_search_reg(Phi[:, supp], y, w[supp], grad_f_supp, k, beta, reg_type=reg_type)
        w[supp] -= step_size * grad_f_supp
        w, supp = projection_to_sparse_nonnegative(w, k, already_k_sparse=True)

        # check convergence and report
        if verbose and iteration % print_interval == 0:
            f, f1, f2 = obj(y, Phi, variance, alpha, w, supp)
            print('at iteration {}:'.format(iteration))
            print('training objective (f1 + alpha * f2) is {}, approximation loss (f1) is {}, selected variance loss '
                  '(alpha * f2) is {}, selected original variance (f2) is {}'.format(f, f1, f2, f2 / (alpha+1e-30)))
        if len(supp) >= k:
            break
    return w, supp
