import torch
import numpy as np

    
def find_min_norm_element(vecs):
    """Return composite weights ``[gamma, 1 - gamma]`` via vanilla min-norm.

    Solves ``min_gamma ||gamma * v1 + (1 - gamma) * v2||^2`` in closed form.
    When the minimizer lies on a boundary of [0, 1], gamma is set to 0.999
    (weight on ``v1``) or 0.001 (weight on ``v2``) instead of exactly 1 / 0.

    Args:
        vecs: pair ``(v1, v2)`` of vectors accepted by ``np.dot``.

    Returns:
        np.ndarray of shape (2,): ``[gamma, 1 - gamma]``.
    """
    v1, v2 = vecs
    v1v1 = np.dot(v1, v1)
    v1v2 = np.dot(v1, v2)
    v2v2 = np.dot(v2, v2)

    # The three cases must be mutually exclusive. With two independent `if`s,
    # the first boundary case would be overwritten by the interior formula,
    # which can then produce gamma outside [0, 1].
    if v1v2 >= v1v1:
        # Case: Fig 1, third column (objective still decreasing at gamma = 1)
        gamma = 0.999
    elif v1v2 >= v2v2:
        # Case: Fig 1, first column (objective increasing at gamma = 0)
        gamma = 0.001
    else:
        # Case: Fig 1, second column. Here v1v2 < v1v1 and v1v2 < v2v2, so
        # the denominator ||v1 - v2||^2 is strictly positive.
        gamma = (v2v2 - v1v2) / (v1v1 + v2v2 - 2 * v1v2)

    sol_vec = np.zeros(2)
    sol_vec[0], sol_vec[1] = gamma, 1 - gamma

    return sol_vec

def find_min_norm_element_l1(vecs, gamma0, alpha):
    """Return composite weights ``[gamma, 1 - gamma]`` via L1-regularized min-norm.

    The closed form matches minimizing
    ``0.5 * ||gamma * v1 + (1 - gamma) * v2||^2 + alpha * |gamma - gamma0|``.
    A 1e-6 term is added to the quadratic coefficient to avoid division by zero.

    Args:
        vecs: pair ``(v1, v2)`` of vectors accepted by ``np.dot``.
        gamma0: anchor weight that the penalty pulls ``gamma`` toward.
        alpha: strength of the L1 penalty.

    Returns:
        np.ndarray of shape (2,): ``[gamma, 1 - gamma]``. A stationary point
        below ``gamma0`` is floored at 0.001; one above ``gamma0`` is capped at
        0.999. Otherwise ``gamma0`` itself is returned.
    """
    first, second = vecs
    sq_first = np.dot(first, first)
    cross = np.dot(first, second)
    sq_second = np.dot(second, second)

    curvature = (sq_first + sq_second - cross * 2) + 1e-6
    offset = sq_second - cross
    # Stationary points of the penalized objective on each side of gamma0.
    below = (offset + alpha) / curvature
    above = (offset - alpha) / curvature

    if below < gamma0:
        weight = max(below, .001)
    elif above > gamma0:
        weight = min(above, .999)
    else:
        # Neither side has a stationary point: the optimum sits at the kink.
        weight = gamma0

    sol_vec = np.zeros(2)
    sol_vec[0], sol_vec[1] = weight, 1 - weight
    return sol_vec

def find_min_norm_element_l2(vecs, gamma0, alpha):
    """Return composite weights ``[gamma, 1 - gamma]`` via L2-regularized min-norm.

    The closed form matches minimizing
    ``0.5 * ||gamma * v1 + (1 - gamma) * v2||^2 + 0.5 * alpha * (gamma - gamma0)^2``.
    The result is clipped to [0, 1].

    Args:
        vecs: pair ``(v1, v2)`` of vectors accepted by ``np.dot``.
        gamma0: anchor weight that the penalty pulls ``gamma`` toward.
        alpha: strength of the quadratic penalty.

    Returns:
        np.ndarray of shape (2,): ``[gamma, 1 - gamma]``.
    """
    a, b = vecs
    aa, ab, bb = np.dot(a, a), np.dot(a, b), np.dot(b, b)

    numer = (bb - ab) + alpha * gamma0
    denom = (aa + bb - ab * 2) + alpha
    weight = np.clip(numer / denom, .0, 1.0)

    sol_vec = np.zeros(2)
    sol_vec[0], sol_vec[1] = weight, 1 - weight
    return sol_vec
