import torch 

from objective_func import obj_F_square, obj_F_square_vec
from lemke import lemkelcp_torch
from Bsemi_func import *


def Z(x):
    """
    Elementwise helper ``-x / (2 * sqrt(x**2 + 4))``.

    For every real input the result lies strictly in (-1/2, 1/2), since
    ``|x| < sqrt(x**2 + 4)``. It is used by ``newt_update`` to build the
    diagonal coefficients of the Newton system.
    """
    denominator = 2 * torch.sqrt(x**2 + 4)
    return -x / denominator

# ==================================================================================
# 1. Proximal gradient Algorithm for ACCORD
# ==================================================================================
def proxg_update(X, omega, d, diag_ind, lamb, tau_init, update_method, device,
                 max_outer_iter=1000, max_backtrack=20, shrink=0.6):
    """
    Proximal Gradient Descent algorithm in ACCORD paper (with backtracking).

    The smooth part is ``g(Omega) = 0.5 * ||Omega @ X.T||_F^2 / n``.

    Each trial point is built in three steps:

    - Take a gradient step of size ``tau``.
    - Apply the proximal map: soft-threshold every entry by ``lamb * tau``.
    - Replace each row's diagonal entry ``Omega[i, diag_ind[i]]`` by
      ``(w + sqrt(w**2 + 4 * tau)) / 2``, where ``w`` is that entry after the
      gradient step.

    The step size starts at ``tau_init`` and is multiplied by ``shrink`` until
    ``g(new) <= Q`` holds, where ``Q`` is the quadratic upper model
    ``g(old) + <grad, D> + ||D||_F^2 / (2 * tau)``.

    Parameters
    ----------
    X : data matrix with ``n`` rows (``n, _ = X.shape``).
    omega : starting iterate; not modified (a clone is used).
    d : dual variable passed to ``obj_F_square`` for the starting residual.
        Only read when ``update_method == 'B-semi'``.
    diag_ind : column index of each row's diagonal entry.
    lamb : l1 penalty weight.
    tau_init : initial step size of every backtracking search.
    update_method : with ``'B-semi'`` a trial point must also lower
        ``obj_F_square`` below its value at ``omega`` to be accepted.
        Otherwise the first point passing the ``g(new) <= Q`` test is returned.
    device : device on which the row-index tensor is created.
    max_outer_iter : maximum number of outer iterations; ``None`` means
        unbounded (the previous behaviour, which can hang).
    max_backtrack : number of step-size trials per outer iteration (>= 1).
    shrink : step-size reduction factor per trial.

    Returns
    -------
    The accepted iterate. If ``max_outer_iter`` outer iterations pass without
    acceptance, the last iterate reached is returned instead.

    Raises
    ------
    ValueError : if ``max_backtrack < 1``.
    """
    if max_backtrack < 1:
        raise ValueError("max_backtrack must be >= 1")

    n, _ = X.shape
    XT = X.T
    nrow = torch.arange(omega.shape[0], device=device)
    use_bsemi = update_method == 'B-semi'

    omega_old = omega.clone()  # never modify the caller's tensor
    Y = omega_old @ XT

    # The starting residual is only needed by the B-semi acceptance test.
    # Skipping it otherwise avoids a host synchronisation through .item().
    org_F_square = (
        obj_F_square(X, omega, d, lamb, diag_ind).item() if use_bsemi else None
    )

    outer = 0
    while max_outer_iter is None or outer < max_outer_iter:
        outer += 1
        g_old = 0.5 * torch.norm(Y, p='fro')**2 / n  # smooth objective at omega_old
        grad = (Y @ X) / n                            # its gradient
        tau = tau_init

        for _ in range(max_backtrack):
            # Forward (gradient) step.
            o_tilde = omega_old - tau * grad

            # Backward (proximal) step: closed-form diagonal update, and
            # soft-thresholding for everything else.
            omega_d = o_tilde[nrow, diag_ind]
            omega_d = (omega_d + torch.sqrt(omega_d ** 2 + 4.0 * tau)) * 0.5
            omega_new = torch.nn.functional.softshrink(o_tilde, lamb * tau)
            omega_new[nrow, diag_ind] = omega_d

            Y = omega_new @ XT
            g_new = 0.5 * torch.norm(Y, p='fro')**2 / n

            D = omega_new - omega_old
            Q = g_old + torch.sum(D * grad) + torch.norm(D, p='fro')**2 / (2.0 * tau)

            # Non-strict test. When omega_old is already a fixed point of the
            # prox map, D == 0 and g_new == Q, so a strict '<' would never
            # accept the point.
            if g_new <= Q:
                if not use_bsemi:
                    return omega_new
                # Y already equals omega_new @ X.T, so reuse it for the dual.
                di = -(Y @ X) / n
                new_F_square = obj_F_square(X, omega_new, di, lamb, diag_ind).item()
                if new_F_square < org_F_square:
                    return omega_new
            tau *= shrink

        # Nothing accepted: continue from the last trial point (Y matches it).
        omega_old = omega_new

    return omega_old

def proxg_update_row(X, omega, d, diag_ind, lamb, tau_init, update_method, device,
                     max_outer_iter=5, max_backtrack=20, shrink=0.6, tol=1e-10):
    """
    Proximal Gradient Descent algorithm utilizing row-separability.

    Works on a single row ``omega`` (presumably a 1-D tensor of length p;
    TODO confirm against callers). The row's diagonal entry is
    ``omega[diag_ind]``. The smooth part is ``0.5 * ||omega @ X.T||^2 / n``.

    Each trial point is built in three steps:

    - Take a gradient step of size ``tau``.
    - Soft-threshold every entry by ``lamb * tau``.
    - Set the diagonal entry to ``(w + sqrt(w**2 + 4 * tau)) / 2``.

    ``tau`` starts at ``tau_init`` and is multiplied by ``shrink`` until
    ``g(new) <= Q`` (the quadratic upper model) holds.

    Parameters
    ----------
    X : data matrix with ``n`` rows.
    omega : starting row; not modified (a clone is used).
    d : dual row used for the starting residual ``obj_F_square_vec``.
    diag_ind : index of the diagonal entry inside the row.
    lamb : l1 penalty weight.
    tau_init : initial step size of every backtracking search.
    update_method : with ``'B-semi'`` a trial point must also lower
        ``obj_F_square_vec`` below its starting value to be accepted.
    device : unused here; kept for signature compatibility.
    max_outer_iter : number of outer iterations.
    max_backtrack : number of step-size trials per outer iteration (>= 1).
    shrink : step-size reduction factor per trial.
    tol : the row counts as converged when the starting residual is below this.

    Returns
    -------
    (omega_row, converged)
        ``converged`` is True only when the starting residual is already below
        ``tol``; the row is then returned unchanged (as a clone). Otherwise the
        accepted iterate is returned with False. If no trial point is accepted
        within ``max_outer_iter`` outer iterations, the most recent iterate is
        returned with False.

    Raises
    ------
    ValueError : if ``max_backtrack < 1``.
    """
    if max_backtrack < 1:
        raise ValueError("max_backtrack must be >= 1")

    n, _ = X.shape
    XT = X.T
    use_bsemi = update_method == 'B-semi'

    omega_old = omega.clone()  # never modify the caller's tensor

    # Starting residual: used for the early exit and by the B-semi test.
    org_F_square = obj_F_square_vec(X, omega, d, lamb, diag_ind).item()
    if org_F_square < tol:
        return omega_old, True

    Y = omega_old @ XT
    for _ in range(max_outer_iter):
        g_old = 0.5 * torch.norm(Y, p='fro')**2 / n  # smooth objective at omega_old
        grad = (Y @ X) / n                            # its gradient
        tau = tau_init

        for _ in range(max_backtrack):
            # Forward (gradient) step.
            o_tilde = omega_old - tau * grad

            # Backward (proximal) step: closed-form diagonal update, and
            # soft-thresholding for everything else.
            omega_d = o_tilde[diag_ind]
            omega_d = (omega_d + torch.sqrt(omega_d ** 2 + 4.0 * tau)) * 0.5
            omega_new = torch.nn.functional.softshrink(o_tilde, lamb * tau)
            omega_new[diag_ind] = omega_d

            Y = omega_new @ XT
            g_new = 0.5 * torch.norm(Y, p='fro')**2 / n

            D = omega_new - omega_old
            Q = g_old + torch.sum(D * grad) + torch.norm(D, p='fro')**2 / (2.0 * tau)

            # Non-strict test: D == 0 (already a fixed point) gives g_new == Q.
            if g_new <= Q:
                if not use_bsemi:
                    return omega_new, False
                # Y already equals omega_new @ X.T, so reuse it for the dual.
                di = -(Y @ X) / n
                new_F_square = obj_F_square_vec(X, omega_new, di, lamb, diag_ind).item()
                if new_F_square < org_F_square:
                    return omega_new, False
            tau *= shrink

        # Nothing accepted: continue from the last trial point (Y matches it).
        omega_old = omega_new

    # Return the most recent iterate, not the one before it.
    return omega_old, False

# ==================================================================================
# Semismooth Newton step for ACCORD (local step used by Algorithm 2 below)
# ==================================================================================
def newt_update(X, omega_diag, d, activation, diag_p, lamb, device):
    """
    Semismooth Newton point for a single row.

    The coordinates are split by the sign of ``activation``:

    - positive active set: indices with positive activation;
    - negative active set: indices with negative activation;
    - inactive set: indices with zero activation, excluding ``diag_p``.

    A reduced linear system over ``[diag_p, active set]`` is solved for the
    primal values. The dual values are then filled in: ``+lamb`` / ``-lamb``
    on the active set, ``bk * omega_diag_new + ck`` on the diagonal, and a
    least-squares-type residual on the inactive set.

    Args:
        X: data matrix of shape (n, p).
        omega_diag: current diagonal primal value(s), i.e. ``omega[diag_p]``
            as passed by ``semi_newt_update``.
        d: current dual row. Only ``d[diag_p]`` is read; its shape also fixes
            the output shapes.
        activation: per-coordinate tensor whose sign selects the active sets.
        diag_p: 1-D index tensor of the diagonal coordinate. It is
            concatenated with the active indices, so it must be a tensor.
        lamb: l1 penalty weight.
        device: device on which index tensors are created.

    Returns:
        (next_omega, next_d): tensors shaped like ``d``. Entries of
        ``next_omega`` outside ``diag_p`` and the active set are zero.
    """
    n, p = X.shape
    coords = torch.arange(p, device=device)

    # Active coordinates: positive ones first, then negative ones. This order
    # must match the right-hand side (+lamb block, then -lamb block) below.
    pos_idx = coords[activation > 0]
    neg_idx = coords[activation < 0]
    act_idx = torch.cat((pos_idx, neg_idx))
    inact_idx = coords[(activation[coords] == 0) & (coords != diag_p)]

    # Diagonal coefficients; the diagonal dual is later set to bk * omega + ck.
    w = omega_diag + d[diag_p]
    z = Z(w)
    bk = -(z + 1/2)/(z - 1/2)
    ck = (z * w + torch.sqrt(w ** 2 + 4)/2)/(z - 1/2)

    n_pos, n_neg = len(pos_idx), len(neg_idx)
    rhs_active = torch.cat((
        torch.ones(n_pos, dtype=X.dtype, device=device) * lamb,
        torch.ones(n_neg, dtype=X.dtype, device=device) * (-lamb),
    ))

    # Reduced Gram matrix over [diag, active...]. The diagonal coordinate sits
    # in row/column 0 and receives the extra bk term.
    sel = torch.cat((diag_p, act_idx))
    X_sel = X[:, sel]
    gram = X_sel.T @ X_sel / n
    gram[0, 0] += bk[0]

    sol = torch.linalg.solve(gram, -torch.cat([ck, rhs_active]))
    w_diag_new = sol[0]
    w_active_new = sol[1:]

    # Duals on the inactive set come from the fitted combination of columns.
    active_part = torch.matmul(X[:, act_idx], w_active_new)
    fitted = X[:, diag_p].flatten() * w_diag_new + active_part
    d_inactive = -torch.matmul(X[:, inact_idx].T, fitted) / n

    next_omega = torch.zeros_like(d)
    next_omega[act_idx] = w_active_new
    next_omega[diag_p] = w_diag_new

    next_d = torch.zeros_like(d)
    next_d[act_idx] = rhs_active
    next_d[diag_p] = bk * w_diag_new + ck
    next_d[inact_idx] = d_inactive

    return next_omega, next_d

# ==================================================================================
# 2. Semismooth Algorithm for ACCORD with global convergence
# ==================================================================================
def semi_newt_update(X, omega, d, activation, diag_p, lamb, rho, device,
                     max_iter_line=128, armijo_const=0.001, batch_size=64, tol=1e-10):
    """
    Semismooth newton algorithm using backtracking line search for global convergence.

    The merit function is ``theta = obj_F_square_vec(X, omega, d, lamb, diag_p)``.

    1. If ``theta < tol``, the row is considered converged and is returned
       unchanged.
    2. Otherwise ``newt_update`` gives a Newton point, and
       ``v = (next_omega - omega, next_d - d)`` is the search direction.
    3. A step ``rho**m`` is sought for ``m = 0, 1, ..., max_iter_line - 1``.
       The first ``m`` satisfying the Armijo-type test
       ``theta(new) - theta <= -2 * armijo_const * rho**m * theta`` is taken.
       Candidates are evaluated ``batch_size`` at a time by passing a stacked
       batch to ``obj_F_square_vec``; this assumes that function returns one
       value per stacked row (TODO confirm).
    4. If no ``m`` passes, or the accepted step underflowed to 0, the fallback
       step ``rho**max_iter_line`` is used.

    Returns:
        (omega, True) when already converged, else (omega + step * v_omega,
        False). The dual direction is used only inside the line search; the
        updated dual is not returned.
    """
    theta_k = obj_F_square_vec(X, omega, d, lamb, diag_p)

    if theta_k < tol:
        # Update of each row is terminated.
        return omega, True

    next_omega, next_d = newt_update(X, omega[diag_p], d, activation, diag_p, lamb, device)

    # Search direction.
    v_k_omega = next_omega - omega
    v_k_d = next_d - d

    # Batched backtracking over exactly m = 0 .. max_iter_line - 1. The last
    # batch is truncated so no exponent >= max_iter_line is tried.
    step = None
    for start in range(0, max_iter_line, batch_size):
        m_ks = torch.arange(start, min(start + batch_size, max_iter_line), device=device)
        rho_mks = rho ** m_ks

        theta_new = obj_F_square_vec(
            X,
            omega + rho_mks[:, None] * v_k_omega,
            d + rho_mks[:, None] * v_k_d,
            lamb,
            diag_p,
        )
        lhs = theta_new - theta_k
        rhs = -2.0 * armijo_const * rho_mks * theta_k

        accepted = torch.where(lhs <= rhs)[0]
        if accepted.numel() > 0:
            step = rho_mks[accepted[0]]  # smallest accepted m
            break

    # A zero step (rho**m underflow) makes no progress; treat it like a
    # failed search and use the smallest step the search was allowed to try.
    if step is None or not step > 0:
        step = rho ** max_iter_line

    return omega + step * v_k_omega, False

# ==================================================================================
# 3. Bouligand-Semismooth Algorithm for ACCORD with global convergence
# ==================================================================================
def bouligand_newt_update(X, omega, d, activation, diag_p, lamb, rho, device, max_lcp_iter=1000):
    """
    Bouligand semismooth newton algorithm from Bsemi_func.py.

    Builds a search direction by solving a linear complementarity problem
    (LCP) with ``lemkelcp_torch``. It then takes a step of size
    ``line_search(...)`` along the primal direction. The helper semantics live
    in ``Bsemi_func`` and are not visible here.

    Args:
        X, omega, d, activation, diag_p, lamb, rho: forwarded to the Bsemi_func
            helpers.
        device: device forwarded to the helpers and to the LCP solver.
        max_lcp_iter: iteration cap passed to ``lemkelcp_torch`` as ``maxIter``.

    Returns:
        ``omega + step_size * v_k_omega`` (the updated primal row).
    """
    # 1) Active and inactive sets (plus the semi-active sets).
    pos_set, semi_pos_set, neg_set, semi_neg_set, active_set, inactive_set = \
        get_active_inactive_sets(activation, diag_p, device)

    # 2) Diagonal coefficients.
    bk, ck = compute_bc(omega, d, diag_p)

    # 3) Initial search direction.
    v_k_omega, v_k_d = initialize_direction_vectors(
        omega, d, pos_set, neg_set, active_set, inactive_set, lamb, device)

    # 4) Matrices for the LCP.
    P, inv_N, M, q = generate_matrix(
        X, omega, d, diag_p, pos_set, semi_pos_set, neg_set, semi_neg_set,
        active_set, inactive_set, bk, ck, v_k_omega, v_k_d, device)

    # 5) Solve the LCP on the caller's device (previously hard-coded to
    #    'cuda', which failed on CPU-only runs).
    sol = lemkelcp_torch(M, q, maxIter=max_lcp_iter, device=device)

    # 6) Update the search direction. sol[0] is assumed to be the LCP solution
    #    vector; lemkelcp_torch's return layout is not visible here (TODO confirm).
    v_k_omega, v_k_d = update_search_direction(
        sol[0], omega, d, diag_p, pos_set, semi_pos_set, neg_set, semi_neg_set,
        active_set, inactive_set, v_k_omega, v_k_d, P, inv_N, device)
    v_k_d[diag_p] = bk * v_k_omega[diag_p] + ck

    # 7) Line search.
    step_size = line_search(X, omega, d, v_k_omega, v_k_d, lamb, diag_p, rho, device)

    # 8) Primal update.
    return omega + step_size * v_k_omega