import numpy as np

class MinNormSolver:

    '''
    Closed-form min-norm solver for multi-objective optimization (MOO) with exactly TWO tasks.

    Every method is static and takes ``vecs = (v1, v2)``: two 1-D array-likes of
    equal length (e.g. each task's parameter gradients flattened and concatenated
    into one vector). Inner products are computed with ``np.dot`` applied directly
    to v1 and v2, so per-layer gradient lists must be flattened by the caller.
    '''

    @staticmethod
    def _inner_products(vecs):
        '''Return (v1.v1, v1.v2, v2.v2) for ``vecs = (v1, v2)``.'''
        v1, v2 = vecs
        return np.dot(v1, v1), np.dot(v1, v2), np.dot(v2, v2)

    @staticmethod
    def _two_task_gamma(v1v1, v1v2, v2v2, lo, hi):
        '''
        Return the gamma minimizing ||gamma*v1 + (1-gamma)*v2||^2 over [0, 1].

        For v1 != v2 the unconstrained minimizer is (v2v2 - v1v2) / ||v1 - v2||^2.
        It is >= 1 exactly when v1v2 >= v1v1, and <= 0 exactly when v1v2 >= v2v2.
        In those boundary cases ``hi`` / ``lo`` is returned instead, so callers
        choose whether the boundary is exactly 1/0 or slightly inside it.
        '''
        if v1v2 >= v1v1:
            # Case: Fig 1, third column
            return hi
        if v1v2 >= v2v2:
            # Case: Fig 1, first column
            return lo
        # Case: Fig 1, second column. Both tests above failed, so v1 != v2 and
        # the denominator ||v1 - v2||^2 is (mathematically) strictly positive.
        return (v2v2 - v1v2) / (v1v1 + v2v2 - 2 * v1v2)

    @staticmethod
    def cal_min_norm(vecs):
        '''
        Return ||gamma*v1 + (1-gamma)*v2||^2 at the two-task min-norm gamma.

        In the boundary cases gamma is set to 0.999 / 0.001 rather than 1 / 0.
        v1 and v2 must support scalar multiplication and addition (e.g. NumPy arrays).
        '''
        v1, v2 = vecs
        v1v1, v1v2, v2v2 = MinNormSolver._inner_products(vecs)
        gamma = MinNormSolver._two_task_gamma(v1v1, v1v2, v2v2, 0.001, 0.999)
        v = gamma * v1 + (1 - gamma) * v2
        return np.dot(v, v)

    @staticmethod
    def find_min_norm_element(vecs):
        '''
        Return the task weights of the min-norm point in the convex hull of v1 and v2.

        - INPUT
        vecs = [task_1_grad, task_2_grad], where each task_i_grad is a 1-D
        array-like (e.g. all of that task's gradients flattened into one vector;
        ``np.dot`` is applied to task_i_grad directly).

        - OUTPUT
        sol_vec = np.array([task_1_weight, task_2_weight]) with
        task_1_weight = gamma in [0, 1] and task_2_weight = 1 - gamma.
        '''
        v1v1, v1v2, v2v2 = MinNormSolver._inner_products(vecs)
        gamma = MinNormSolver._two_task_gamma(v1v1, v1v2, v2v2, 0.0, 1.0)
        return np.array([gamma, 1 - gamma], dtype=float)

    @staticmethod
    def find_min_norm_element_l2(vecs, gamma_prev, beta):
        '''
        Return task weights [gamma, 1 - gamma] with gamma pulled toward gamma_prev.

        gamma is the stationary point of
            beta/2 * ||gamma*v1 + (1-gamma)*v2||^2 + (1-beta)/2 * (gamma - gamma_prev)^2,
        clipped to [0, 1]. If the denominator of that closed form is zero
        (e.g. beta == 1 and v1 == v2, where every gamma is optimal), gamma_prev
        is kept instead of producing NaN.
        '''
        v1v1, v1v2, v2v2 = MinNormSolver._inner_products(vecs)
        numer = beta * (v2v2 - v1v2) + (1 - beta) * gamma_prev
        denom = beta * (v1v1 + v2v2 - v1v2 * 2) + (1 - beta)
        gamma = numer / denom if denom != 0 else gamma_prev
        gamma = np.clip(gamma, .0, 1.0)
        return np.array([gamma, 1 - gamma], dtype=float)

    @staticmethod
    def _l1_gamma(vecs, gamma_prev, beta, lo, hi):
        '''
        Soft-thresholded gamma update shared by the two L1 variants.

        gammaL / gammaR are the stationary points of
            beta/2 * ||gamma*v1 + (1-gamma)*v2||^2 + (1-beta) * |gamma - gamma_prev|
        on the branches gamma < gamma_prev and gamma > gamma_prev respectively.
        1e-6 is added to the denominator to avoid division by zero. If neither
        stationary point lies on its own branch, gamma stays at gamma_prev.
        Only the direction of movement is clamped: a decrease is bounded below
        by ``lo``, an increase is bounded above by ``hi``.
        '''
        v1v1, v1v2, v2v2 = MinNormSolver._inner_products(vecs)
        pull = beta * (v2v2 - v1v2)
        denom = beta * (v1v1 + v2v2 - v1v2 * 2) + 1e-6
        gammaL = (pull + (1 - beta)) / denom
        gammaR = (pull - (1 - beta)) / denom
        if gammaL < gamma_prev:
            return max(gammaL, lo)
        if gammaR > gamma_prev:
            return min(gammaR, hi)
        return gamma_prev

    @staticmethod
    def find_min_norm_element_l1(vecs, gamma_prev, beta):
        '''
        Return L1-regularized task weights [gamma, 1 - gamma].

        A decreasing gamma is floored at 0.001 and an increasing one capped at
        0.999, so an update never drives either weight to exactly zero.
        '''
        gamma = MinNormSolver._l1_gamma(vecs, gamma_prev, beta, .001, .999)
        return np.array([gamma, 1 - gamma], dtype=float)

    @staticmethod
    def find_min_norm_element_l1_v2(vecs, gamma_prev, beta):
        '''
        Return L1-regularized task weights [gamma, 1 - gamma].

        Same as ``find_min_norm_element_l1`` but the update may reach the exact
        boundaries 0.0 and 1.0.
        '''
        gamma = MinNormSolver._l1_gamma(vecs, gamma_prev, beta, .0, 1.0)
        return np.array([gamma, 1 - gamma], dtype=float)