import torch
from objective_func import obj_F_square_vec

"""
Functions for Bsemismooth algorithm
"""

def Z(x):
    """Return ``-x / (2 * sqrt(x**2 + 4))`` evaluated elementwise on tensor *x*."""
    denominator = 2 * torch.sqrt(x**2 + 4)
    return -x / denominator
    
def get_active_inactive_sets(activation, diag_p, device):
    """
    Split the positions of *activation* into index sets by activation code.

    Codes: 1 -> pos_set, -1 -> neg_set, 0.5 -> semi_pos_set,
    -0.5 -> semi_neg_set, 0 -> inactive_set (positions equal to *diag_p*
    are left out of the inactive set).

    Returns:
        (pos_set, semi_pos_set, neg_set, semi_neg_set, active_set,
        inactive_set), where active_set is pos_set followed by neg_set.
    """
    positions = torch.arange(activation.shape[0], device=device)

    def _positions_with(code):
        # Positions whose activation equals the given code.
        return positions[activation == code]

    pos_set = _positions_with(1)
    neg_set = _positions_with(-1)
    semi_pos_set = _positions_with(0.5)
    semi_neg_set = _positions_with(-0.5)

    # Zero-coded positions, except the diagonal one(s).
    off_diagonal = positions != diag_p
    inactive_set = positions[(activation == 0) & off_diagonal]

    # Order matters to callers: positives first, then negatives.
    active_set = torch.cat((pos_set, neg_set))
    return pos_set, semi_pos_set, neg_set, semi_neg_set, active_set, inactive_set


def compute_bc(omega, d, diag_p):
    """
    Compute the coefficients (bk, ck) for the entries selected by *diag_p*.

    With ``s = omega[diag_p] + d[diag_p]`` and ``z = Z(s)``:
        bk = -(z + 1/2) / (z - 1/2)
        ck = -(omega[diag_p] - (s + sqrt(s**2 + 4)) / 2) / (z - 1/2)

    Returns:
        Tuple ``(bk, ck)`` with the same shape as ``omega[diag_p]``.
    """
    omega_diag = omega[diag_p]
    diag = omega_diag + d[diag_p]
    zk = Z(diag)

    # Both coefficients share this denominator and the square-root term.
    denominator = zk - 1/2
    root = torch.sqrt(diag**2 + 4)

    bk = -(zk + 1/2) / denominator
    ck = -(omega_diag - (diag + root) / 2) / denominator
    return bk, ck


def initialize_direction_vectors(omega, d, pos_set, neg_set, active_set, inactive_set, lamb, device):
    """
    Create the initial direction vectors for omega and d.

    The omega direction is ``-omega`` on *inactive_set* and zero elsewhere.
    The d direction on *active_set* is ``target - d``, where the target is
    ``+lamb`` for the first ``len(pos_set)`` active entries and ``-lamb``
    for the remaining ``len(neg_set)`` ones; it is zero elsewhere.
    *active_set* is therefore expected to list pos_set before neg_set.

    Returns:
        Tuple ``(v_k_omega, v_k_d)`` of new tensors shaped like omega and d.
    """
    # Targets in active_set order: +lamb block, then -lamb block.
    targets = torch.cat((
        torch.ones(len(pos_set), device=device) * lamb,
        torch.ones(len(neg_set), device=device) * (-lamb),
    ))

    direction_d = torch.zeros_like(d)
    direction_d[active_set] = targets - d[active_set]

    direction_omega = torch.zeros_like(omega)
    direction_omega[inactive_set] = -omega[inactive_set]

    return direction_omega, direction_d

def generate_matrix(X, omega, d, diag_p, pos_set, semi_pos_set, neg_set, semi_neg_set, active_set, inactive_set, bk, ck, v_k_omega, v_k_d, device):
    """
    Build the vector P, the direction v_next and the pair (M, q).

    Returned vectors are ordered as
    ``[diag_p, active_set, semi_pos_set, semi_neg_set, inactive_set]``.
    The first four groups index the square matrix N_1; the inactive
    entries come last. The code uses ``N_1[0, 0]``, ``bk[0]`` and ``ck[0]``,
    so *diag_p* is expected to hold exactly one index.

    Args:
        X: 2-D tensor; its first dimension ``n`` is used as the divisor.
        omega, d: 1-D tensors indexed by the index sets.
        diag_p, active_set, semi_pos_set, semi_neg_set, inactive_set:
            1-D index tensors (any of the last four may be empty).
        pos_set, neg_set: unused here; kept for interface compatibility.
        bk, ck: coefficient tensors; only element 0 of each is read.
        v_k_omega, v_k_d: current direction vectors (read only).
        device: device on which the sign vector is allocated.

    Returns:
        P: reordered right-hand side vector.
        v_next: ``inv(N_1) @ P[:k]`` followed by the inactive part.
        M: trailing semi-active block of ``inv(N_1)``; shape ``(0, 0)``
            when there are no semi-active indices.
        q: vector with one entry per semi-active index (empty if none).
    """
    n, _ = X.shape
    n_active = len(active_set)
    n_semi_neg = len(semi_neg_set)
    n_semi = len(semi_pos_set) + n_semi_neg

    # P
    P = - (torch.matmul(torch.matmul(omega, X.T), X) + torch.matmul(X.T , torch.matmul(X[:, inactive_set], v_k_omega[inactive_set]))) / n - d
    P[diag_p] -= ck[0]
    P[active_set] -= v_k_d[active_set]
    P[semi_neg_set] *= -1
    P = P[torch.cat((diag_p, active_set, semi_pos_set, semi_neg_set, inactive_set))]

    # N_1
    indices = torch.cat((diag_p, active_set, semi_pos_set, semi_neg_set))
    X_block = X[:, indices]
    N_1 = X_block.T @ X_block / n
    N_1[0, 0] += bk[0]
    # Guarded because a "-0:" slice would select everything, not nothing.
    if n_semi_neg > 0:
        N_1[-n_semi_neg:, :] *= -1
        N_1[:, -n_semi_neg:] *= -1
    block_size = N_1.shape[0]

    # inverse of N_1
    inv_N_1 = torch.linalg.inv(N_1)

    # v_next
    # Sign vector: +1 everywhere except the trailing semi-negative columns.
    scale = torch.ones(block_size, device=device)
    scale[block_size - n_semi_neg:] = -1.0

    # Computed once and reused for both parts of v_next.
    head = torch.matmul(inv_N_1, P[:block_size])
    projected = torch.matmul((X_block * scale.unsqueeze(0)) / n, head)

    v_next = torch.cat((
        head,
        -torch.matmul(X[:, inactive_set].T, projected) + P[block_size:]
    ), dim=0)

    # M
    # Use an explicit start offset: with no semi-active indices the original
    # negative slice "-0:" returned the whole inverse and the product with the
    # empty cat_omega below raised a size-mismatch error.
    semi_start = block_size - n_semi
    M = inv_N_1[semi_start:, semi_start:]

    # q
    cat_omega = torch.cat((omega[semi_pos_set], -omega[semi_neg_set]), dim=0)
    q = - M @ cat_omega + inv_N_1[n_active + 1:n_active + n_semi + 1, :] @ P[0:block_size] + cat_omega
    return P, v_next, M, q

def update_search_direction(sol, omega, d, diag_p, pos_set, semi_pos_set, neg_set, semi_neg_set, active_set, inactive_set, v_k_omega, v_k_d, P, v_next, device):
    """
    Write the semi-active solution *sol* and *v_next* into the directions.

    *P* and *v_next* are read in the layout
    ``[1 diagonal entry, active_set, semi_pos_set, semi_neg_set, inactive_set]``.
    *v_k_omega*, *v_k_d* and *P* are modified in place; *d*, *pos_set*,
    *neg_set* and *device* are not used.

    Returns:
        Tuple ``(v_k_omega, v_k_d)`` (the same tensor objects passed in).
    """
    n_semi_pos = len(semi_pos_set)

    # Slice boundaries of each group inside P / v_next.
    semi_pos_start = len(active_set) + 1
    semi_neg_start = semi_pos_start + n_semi_pos
    inactive_start = semi_neg_start + len(semi_neg_set)

    # d-direction on the semi-active entries comes from sol.
    v_k_d[semi_pos_set] = omega[semi_pos_set] - sol[:n_semi_pos]
    v_k_d[semi_neg_set] = sol[n_semi_pos:] + omega[semi_neg_set]

    # Fold those values into the matching slices of P (in place).
    P[semi_pos_start:semi_neg_start] -= v_k_d[semi_pos_set]
    P[semi_neg_start:inactive_start] += v_k_d[semi_neg_set]

    # Scatter v_next back to the original index positions.
    v_k_omega[diag_p] = v_next[0]
    v_k_omega[active_set] = v_next[1:semi_pos_start]
    v_k_omega[semi_pos_set] = v_next[semi_pos_start:semi_neg_start]
    v_k_omega[semi_neg_set] = -v_next[semi_neg_start:inactive_start]
    v_k_d[inactive_set] = v_next[inactive_start:]

    return v_k_omega, v_k_d

def line_search(X, omega, d, v_k_omega, v_k_d, lamb, diag_p, rho, device, max_iter_line = 128, armijo_const = 0.01):
    """
    Backtracking line search over step sizes ``rho ** m``, 64 at a time.

    Candidate steps for exponents ``m, m + 1, ..., m + 63`` are evaluated
    in a single call to ``obj_F_square_vec`` (presumably it returns one
    value per candidate row -- confirm against its definition). The first
    candidate in the batch satisfying
    ``theta_new - theta_k <= -2 * armijo_const * step * theta_k`` is taken.

    Returns:
        The accepted step (a 0-dim tensor) when one is found and is
        positive; otherwise ``rho ** max_iter_line``.
    """
    theta_k = obj_F_square_vec(X, omega, d, lamb, diag_p)
    offsets = torch.arange(64, device=device)

    exponent = 0
    while exponent < max_iter_line:
        steps = rho ** (exponent + offsets)
        step_column = steps[:, None]

        theta_trial = obj_F_square_vec(X, omega + step_column * v_k_omega, d + step_column * v_k_d, lamb, diag_p)
        decrease = theta_trial - theta_k
        bound = -2.0 * armijo_const * steps * theta_k

        accepted = torch.where(decrease <= bound)[0]
        if accepted.numel() > 0:
            step = steps[accepted[0]]
            if step > 0:
                return step
            # Accepted step is not positive: stop and use the fallback.
            break

        exponent += 64

    return rho ** max_iter_line
