import numpy as np
import scipy.sparse as sp
import torch
import torch.optim as optim
from typing import Optional
from tqdm import tqdm


def get_partition(P:sp.csr_matrix) -> list:
    """
    Group the column indices of a coarsening matrix by row.

    Each row of ``P`` is treated as one cluster; the cluster's members are the
    columns holding a stored non-zero in that row.

    args:
        P(sp.csr_matrix) : coarsening matrix (any scipy sparse format supporting row slicing)
    return:
        partition : list of np.ndarray, entry ``i`` holding the column indices
            of the non-zero entries of row ``i``
    """
    n_rows = P.shape[0]
    # ``nonzero`` on a 1-row slice returns (row_ids, col_ids); only the columns matter.
    return [P[row].nonzero()[1] for row in range(n_rows)]



def minimize_rsa_support_l_normalized(mu_init:torch.tensor,L:sp.csr_matrix,R:sp.csr_matrix,P_init:sp.csr_matrix,
                                      Q:sp.csr_matrix,
                                      lr:Optional[float] = 0.01,
                                      n_iter:Optional[int] = 100,
                                      name_optim:Optional[str] = "Adam",
                                      momentum:Optional[float] = 0.9,
                                      keep_historic:Optional[bool] = False,
                                      project_generalized_inverse:Optional[bool] = True,
                                      degree_coarsened:Optional[np.ndarray] = None,
                                      degree_original:Optional[np.ndarray] = None,
                                      device:Optional[torch.device] = torch.device('cpu')
                                      ) -> tuple[torch.tensor,list,torch.tensor, list, list, list]:
    """
    Learn the weights ``mu`` of a fixed-support coarsening matrix (normalized Laplacian variant).

    At every iteration:
      1. the coarsening matrix is rebuilt from ``mu`` on the sparsity pattern of ``P_init``;
      2. the loss ``||L^{1/2} (I - Q P) R_dot||_2`` is computed and minimized by one optimizer step;
      3. optionally, ``mu`` is projected so that in every cluster ``i``
         ``sum_j mu_j * sqrt(degree_original_j) == sqrt(degree_coarsened_i)``.

    args:
        mu_init(torch.tensor) : initial weights (one per column of P_init)
        L(sp.csr_matrix) : Laplacian matrix of the graph
        R(sp.csr_matrix) : matrix with K columns; its columns are rescaled by the inverse
            square roots of the K smallest eigenvalues of L (near-zero ones map to 0)
        P_init(sp.csr_matrix) : coarsening matrix whose sparsity pattern is kept
        Q(sp.csr_matrix) : lifting matrix; its transpose defines the partition
        lr(float) : learning rate
        n_iter(int) : number of iterations
        name_optim(str) : one of "Adam", "SGD", "RMSprop", "RADAM"
        momentum(float) : momentum, only used by SGD
        keep_historic(bool) : if True, record mu, P and the loss at every iteration
        project_generalized_inverse(bool) : if True, apply the degree-normalized projection
            after every step
        degree_coarsened(np.ndarray) : per-cluster degree values (indexed by cluster);
            required when project_generalized_inverse is True
        degree_original(np.ndarray) : per-node degree values (indexed by node);
            required when project_generalized_inverse is True
        device(torch.device) : device on which the dense tensors are allocated
    return:
        mu(torch.tensor) : final weights
        loss_l(list) : loss value (float) at every iteration, computed before the step
        P_torch_mu_final(torch.tensor) : sparse coarsening matrix built from the final mu
        mu_historic(list) : mu after every iteration (empty unless keep_historic)
        loss_historic(list) : loss at every iteration as numpy arrays (empty unless keep_historic)
        P_torch_mu_historic(list) : dense P used at every iteration (empty unless keep_historic)
    raises:
        ValueError : unknown ``name_optim``, or missing degrees while projecting
    """
    mu_historic = []
    loss_historic = []
    P_torch_mu_historic = []
    mu = mu_init.clone().detach().requires_grad_(True)
    if name_optim == "Adam":
        optimizer_mu = optim.Adam([mu], lr=lr)
    elif name_optim == "SGD":
        optimizer_mu = optim.SGD([mu], lr=lr, momentum=momentum)
    elif name_optim == "RMSprop":
        optimizer_mu = optim.RMSprop([mu], lr=lr)
    elif name_optim == "RADAM":
        optimizer_mu = optim.RAdam([mu], lr=lr)
    else:
        raise ValueError(f"Unknown name_optim {name_optim!r}; choose between Adam, SGD, RMSprop and RADAM")

    # Validate before the (expensive) eigendecomposition and before touching mu.
    if project_generalized_inverse:
        if degree_coarsened is None:
            raise ValueError("degree_coarsened must be provided if project_generalized_inverse is True")
        if degree_original is None:
            raise ValueError("degree_original must be provided if project_generalized_inverse is True")

    partition_ground = get_partition(Q.T)
    L_torch = torch.tensor(L.toarray(), dtype=torch.float32, device=device)
    loss_l = []

    _, K = R.shape
    R_torch = torch.tensor(R.toarray(), dtype=torch.float32, device=device)
    eig_val, eig_vec = torch.linalg.eigh(L_torch)

    # Eigenvalues below 1e-4 are treated as exact zeros: they contribute 0 to
    # L^{1/2} and their inverse square root is set to 0 (pseudo-inverse behaviour).
    mask_zero_eig = eig_val < 1e-4
    eig_val[mask_zero_eig] = 0
    eig_val_sqrt = eig_val**0.5
    L_sqrt = eig_vec @ torch.diag(eig_val_sqrt) @ eig_vec.T
    eig_val[mask_zero_eig] = 1  # avoid division by zero before masking
    eig_val_inv_sqrt = eig_val**(-0.5)
    eig_val_inv_sqrt[mask_zero_eig] = 0
    R_dot_eig = R_torch @ torch.diag(eig_val_inv_sqrt[:K])
    Q_torch = torch.tensor(Q.toarray(), dtype=torch.float32, device=device)

    if degree_coarsened is not None:
        sqrt_degree_coarsened_mu = np.sqrt(degree_coarsened)
    if degree_original is not None:
        sqrt_degree_original_mu = np.sqrt(degree_original)
    for _ in tqdm(range(n_iter)):
        optimizer_mu.zero_grad()
        P_torch_mu = torch_P_according_mu(P_init, mu)
        loss = compute_rsa_differentiable(P_torch_mu, Q_torch, L_sqrt, R_dot_eig)
        loss_l.append(loss.item())
        loss.backward()
        optimizer_mu.step()
        if project_generalized_inverse:
            # In-place update keeps `mu` the same leaf tensor the optimizer tracks.
            with torch.no_grad():
                new_mu = project_mu_to_normalized_reflexive_support(
                    mu, partition_ground, sqrt_degree_coarsened_mu, sqrt_degree_original_mu)
                mu.copy_(new_mu)
        if keep_historic:
            mu_historic.append(mu.clone().detach().cpu().numpy())
            P_torch_mu_historic.append(P_torch_mu.clone().to_dense().detach().cpu().numpy())
            loss_historic.append(loss.detach().cpu().numpy())
    P_torch_mu_final = torch_P_according_mu(P_init, mu)
    return mu, loss_l, P_torch_mu_final, mu_historic, loss_historic, P_torch_mu_historic

def compute_rsa_differentiable(P:torch.sparse.FloatTensor,P_inv:torch.FloatTensor,
                               L_sqrt:torch.FloatTensor,R_dot:torch.FloatTensor) -> torch.FloatTensor:
    """
    Differentiable torch version of the RSA loss.

    With ``Pi = P_inv @ P``, return the spectral norm (largest singular value)
    of ``L_sqrt @ (I - Pi) @ R_dot``.

    Args:
        P : coarsening matrix (callers pass a sparse COO or a dense tensor)
        P_inv : lifting matrix, so that ``P_inv @ P`` is square of size ``P.shape[1]``
        L_sqrt : square root of the Laplacian
        R_dot : R with its columns pre-scaled by the caller
    Returns:
        0-d tensor holding the loss; gradients flow back through ``P``
    """
    identity = torch.eye(P.shape[1], device=P.device)
    complement = identity - P_inv @ P
    return torch.linalg.matrix_norm(L_sqrt @ complement @ R_dot, ord=2)



def torch_P_according_mu(P:sp.csr_matrix,mu:torch.tensor):
    """
    Build a differentiable sparse copy of ``P`` whose non-zero values come from ``mu``.

    Every stored non-zero ``P[i, j]`` is replaced by ``mu[j]``. One weight per
    column of ``P`` is therefore shared by every row where that column is non-zero.
    The sparsity pattern and shape of ``P`` are kept; its original values are discarded.

    args:
        P(sp.csr_matrix) : matrix providing the sparsity pattern and the output shape
        mu(torch.tensor) : weights indexed by column of ``P`` (assumed 1-D)
    return:
        P_torch_mu(torch.Tensor) : sparse COO tensor of shape ``P.shape`` on ``mu.device``;
            gradients flow back to ``mu``
    """
    row_indices, col_indices = P.nonzero()
    # scipy yields int32 host arrays; torch indexing and COO construction expect
    # int64 indices on the same device as ``mu``, so convert once explicitly.
    indices = torch.as_tensor(np.vstack((row_indices, col_indices)), dtype=torch.long, device=mu.device)
    new_values = mu[indices[1]]
    return torch.sparse_coo_tensor(indices, values=new_values, size=P.shape, device=mu.device)


def project_mu_to_normalized_reflexive_support(mu:torch.tensor,partition:list,sqrt_degree_coarsened_mu, sqrt_degree_original_mu) -> torch.tensor:
    """
    Rescale the weights ``mu`` cluster by cluster to a degree-normalized target.

    The weights are first clamped to at least 1e-4. Then, for each cluster ``i``,
    they are scaled so that ``sum_j mu[j] * sqrt_degree_original_mu[j]`` equals
    ``sqrt_degree_coarsened_mu[i]``. A cluster whose weighted sum is not above
    1e-8 is left unscaled and reported on stdout.

    args:
        mu : torch.tensor of shape (n_nodes,); not modified
        partition : list of np.ndarray, the node indices of each cluster
        sqrt_degree_coarsened_mu : per-cluster targets, indexed by cluster position
        sqrt_degree_original_mu : per-node weights, indexed by node
    return:
        mu_projected : new torch.tensor of shape (n_nodes,)
    """
    projected = torch.clamp(mu.clone(), min=0.0001)
    for cluster_id, members in enumerate(partition):
        # Element-by-element accumulation, in the order the members are listed.
        weighted_sum = sum(projected[node] * sqrt_degree_original_mu[node] for node in members)
        if not weighted_sum > 1e-8:
            print("Error the cluster has zero sum")
            print("mu[cluster]", mu[members])
            continue
        projected[members] = projected[members] * sqrt_degree_coarsened_mu[cluster_id] / weighted_sum
    return projected


def minimize_rsa_support(mu_init:torch.tensor,L:sp.csr_matrix,R:sp.csr_matrix,P_init:sp.csr_matrix,
                         Q:sp.csr_matrix,
                         lr:Optional[float] = 0.01,
                         n_iter:Optional[int] = 100,
                         name_optim:Optional[str] = "Adam",
                         momentum:Optional[float] = 0.9,
                         keep_historic:Optional[bool] = False,
                         device:Optional[torch.device] = torch.device('cpu')
                         ) -> tuple[torch.tensor,list,torch.tensor, list, list, list]:
    """
    Learn the weights ``mu`` of a fixed-support coarsening matrix.

    At every iteration:
      1. the coarsening matrix is rebuilt from ``mu`` on the sparsity pattern of ``P_init``;
      2. the loss ``||L^{1/2} (I - Q P) R_dot||_2`` is computed and minimized by one optimizer step;
      3. ``mu`` is projected so that its entries sum to 1 within every cluster.

    args:
        mu_init(torch.tensor) : initial weights (one per column of P_init)
        L(sp.csr_matrix) : Laplacian matrix of the graph
        R(sp.csr_matrix) : matrix with K columns; its columns are rescaled by the inverse
            square roots of the K smallest eigenvalues of L (near-zero ones map to 0)
        P_init(sp.csr_matrix) : coarsening matrix whose sparsity pattern is kept
        Q(sp.csr_matrix) : lifting matrix; its transpose defines the partition
        lr(float) : learning rate
        n_iter(int) : number of iterations
        name_optim(str) : one of "Adam", "SGD", "RMSprop", "RADAM"
        momentum(float) : momentum, only used by SGD
        keep_historic(bool) : if True, record mu, P and the loss at every iteration
        device(torch.device) : device on which the dense tensors are allocated
    return:
        mu(torch.tensor) : final weights
        loss_l(list) : loss value (float) at every iteration, computed before the step
        P_torch_mu_final(torch.tensor) : sparse coarsening matrix built from the final mu
        mu_historic(list) : mu after every iteration (empty unless keep_historic)
        loss_historic(list) : loss at every iteration as numpy arrays (empty unless keep_historic)
        P_torch_mu_historic(list) : dense P used at every iteration (empty unless keep_historic)
    raises:
        ValueError : unknown ``name_optim``
    """
    mu_historic = []
    loss_historic = []
    P_torch_mu_historic = []
    mu = mu_init.clone().detach().requires_grad_(True)
    if name_optim == "Adam":
        optimizer_mu = optim.Adam([mu], lr=lr)
    elif name_optim == "SGD":
        optimizer_mu = optim.SGD([mu], lr=lr, momentum=momentum)
    elif name_optim == "RMSprop":
        optimizer_mu = optim.RMSprop([mu], lr=lr)
    elif name_optim == "RADAM":
        optimizer_mu = optim.RAdam([mu], lr=lr)
    else:
        raise ValueError(f"Unknown name_optim {name_optim!r}; choose between Adam, SGD, RMSprop and RADAM")
    partition_ground = get_partition(Q.T)
    L_torch = torch.tensor(L.toarray(), dtype=torch.float32, device=device)
    loss_l = []
    _, K = R.shape
    R_torch = torch.tensor(R.toarray(), dtype=torch.float32, device=device)
    eig_val, eig_vec = torch.linalg.eigh(L_torch)
    # Eigenvalues below 1e-4 are treated as exact zeros: they contribute 0 to
    # L^{1/2} and their inverse square root is set to 0 (pseudo-inverse behaviour).
    mask_zero_eig = eig_val < 1e-4
    eig_val[mask_zero_eig] = 0
    eig_val_sqrt = eig_val**0.5
    L_sqrt = eig_vec @ torch.diag(eig_val_sqrt) @ eig_vec.T
    eig_val[mask_zero_eig] = 1  # avoid division by zero before masking
    eig_val_inv_sqrt = eig_val**(-0.5)
    eig_val_inv_sqrt[mask_zero_eig] = 0
    R_dot_eig = R_torch @ torch.diag(eig_val_inv_sqrt[:K])
    Q_torch = torch.tensor(Q.toarray(), dtype=torch.float32, device=device)
    for _ in tqdm(range(n_iter)):
        optimizer_mu.zero_grad()
        P_torch_mu = torch_P_according_mu(P_init, mu)
        loss = compute_rsa_differentiable(P_torch_mu, Q_torch, L_sqrt, R_dot_eig)
        loss_l.append(loss.item())
        loss.backward()
        optimizer_mu.step()
        # In-place update keeps `mu` the same leaf tensor the optimizer tracks.
        with torch.no_grad():
            new_mu = project_mu_to_uniform(mu, partition_ground)
            mu.copy_(new_mu)
        if keep_historic:
            mu_historic.append(mu.clone().detach().cpu().numpy())
            P_torch_mu_historic.append(P_torch_mu.clone().to_dense().detach().cpu().numpy())
            loss_historic.append(loss.detach().cpu().numpy())
    P_torch_mu_final = torch_P_according_mu(P_init, mu)
    return mu, loss_l, P_torch_mu_final, mu_historic, loss_historic, P_torch_mu_historic


def project_mu_to_uniform(mu:torch.tensor,partition:list) -> torch.tensor:
    """
    Normalize the weights ``mu`` so that they sum to 1 inside every cluster.

    The weights are first clamped to at least 1e-4. A cluster whose sum is not
    above 1e-5 is left unscaled and reported on stdout. Every cluster whose
    resulting sum differs from 1 by more than 1e-5 is also reported.

    args:
        mu : torch.tensor of shape (n_nodes,); not modified
        partition : list of np.ndarray, the node indices of each cluster
    return:
        mu_projected : new torch.tensor of shape (n_nodes,)
    """
    projected = torch.clamp(mu.clone(), min=0.0001)
    for members in partition:
        values = projected[members]
        total = torch.sum(values)
        if total > 1e-5:
            projected[members] = values / total
        else:
            print("Error the cluster has zero sum")
            print("mu[cluster]", mu[members])
        # Sanity check, performed for every cluster whether or not it was rescaled.
        new_total = projected[members].sum()
        if torch.abs(new_total - 1) > 1e-5:
            print("Error the cluster has not been normalized)", new_total)
    return projected


def minimize_rsa_Q_g(P_init:sp.csr_matrix,L:sp.csr_matrix,R:sp.csr_matrix,
                     Q:sp.csr_matrix, P_MP:sp.csr_matrix,
                     lr:Optional[float] = 0.01,
                     n_iter:Optional[int] = 100,
                     device:Optional[torch.device] = torch.device('cpu'),
                     name_optim:Optional[str] = "Adam",
                     momentum:Optional[float] = 0.9,
                     keep_historic:Optional[bool] = False,
                     ) -> tuple[torch.tensor,list,torch.tensor, list, list]:
    """
    Learn a coarsening matrix constrained to the family ``P = P_MP + M (I_N - Q P_MP)``.

    The free matrix ``M`` (initialized from ``P_init``) is optimized. At every
    iteration ``P`` is formed from ``M``, and the loss
    ``||L^{1/2} (I - Q P) R_dot||_2`` is computed and minimized by one optimizer step.

    args:
        P_init(sp.csr_matrix) : initial value of the free matrix M
        L(sp.csr_matrix) : Laplacian matrix of the graph
        R(sp.csr_matrix) : matrix with K columns; its columns are rescaled by the inverse
            square roots of the K smallest eigenvalues of L (near-zero ones map to 0)
        Q(sp.csr_matrix) : lifting matrix
        P_MP(sp.csr_matrix) : affine offset of the family, presumably the Moore-Penrose
            pseudo-inverse of Q
        lr(float) : learning rate
        n_iter(int) : number of iterations
        device(torch.device) : device on which the dense tensors are allocated
        name_optim(str) : one of "Adam", "SGD", "RMSprop", "RADAM"
        momentum(float) : momentum, only used by SGD
        keep_historic(bool) : if True, record P and the loss at every iteration
    return:
        M_torch(torch.tensor) : final free matrix M
        loss_l(list) : loss value (float) at every iteration, computed before the step
        P_torch_M_final(torch.tensor) : dense coarsening matrix built from the final M
        loss_historic(list) : loss at every iteration as numpy arrays (empty unless keep_historic)
        P_torch_M_historic(list) : P used at every iteration (empty unless keep_historic)
    raises:
        ValueError : unknown ``name_optim``
    """
    loss_historic = []
    P_torch_M_historic = []

    M_torch = torch.tensor(P_init.toarray(), dtype=torch.float32, requires_grad=True, device=device)
    P_MP_torch = torch.tensor(P_MP.toarray(), dtype=torch.float32, device=device)
    if name_optim == "Adam":
        optimizer_m = optim.Adam([M_torch], lr=lr)
    elif name_optim == "SGD":
        optimizer_m = optim.SGD([M_torch], lr=lr, momentum=momentum)
    elif name_optim == "RMSprop":
        optimizer_m = optim.RMSprop([M_torch], lr=lr)
    elif name_optim == "RADAM":
        optimizer_m = optim.RAdam([M_torch], lr=lr)
    else:
        raise ValueError(f"Unknown name_optim {name_optim!r}; choose between Adam, SGD, RMSprop and RADAM")
    L_torch = torch.tensor(L.toarray(), dtype=torch.float32, device=device)
    loss_l = []

    _, K = R.shape
    R_torch = torch.tensor(R.toarray(), dtype=torch.float32, device=device)
    eig_val, eig_vec = torch.linalg.eigh(L_torch)

    # Eigenvalues below 1e-4 are treated as exact zeros: they contribute 0 to
    # L^{1/2} and their inverse square root is set to 0 (pseudo-inverse behaviour).
    mask_zero_eig = eig_val < 1e-4
    eig_val[mask_zero_eig] = 0
    eig_val_sqrt = eig_val**0.5
    L_sqrt = eig_vec @ torch.diag(eig_val_sqrt) @ eig_vec.T
    eig_val[mask_zero_eig] = 1  # avoid division by zero before masking
    eig_val_inv_sqrt = eig_val**(-0.5)
    eig_val_inv_sqrt[mask_zero_eig] = 0
    R_dot_eig = R_torch @ torch.diag(eig_val_inv_sqrt[:K])
    Q_torch = torch.tensor(Q.toarray(), dtype=torch.float32, device=device)

    for _ in tqdm(range(n_iter)):
        optimizer_m.zero_grad()
        P_torch_M = project_generalized_inverse(M_torch, P_MP_torch, Q_torch)
        loss = compute_rsa_differentiable(P_torch_M, Q_torch, L_sqrt, R_dot_eig)
        loss_l.append(loss.item())
        loss.backward()
        optimizer_m.step()
        if keep_historic:
            P_torch_M_historic.append(P_torch_M.clone().to_dense().detach().cpu().numpy())
            loss_historic.append(loss.detach().cpu().numpy())
    P_torch_M_final = project_generalized_inverse(M_torch, P_MP_torch, Q_torch)
    return M_torch, loss_l, P_torch_M_final, loss_historic, P_torch_M_historic


def project_generalized_inverse(M_torch:torch.tensor, P_MP:torch.tensor, Q:torch.tensor) -> torch.tensor:
    """
    Map a free matrix ``M_torch`` onto the affine family ``P_MP + M (I_N - Q P_MP)``.

    If ``P_MP @ Q`` is the identity (e.g. ``P_MP`` is the Moore-Penrose
    pseudo-inverse of a full-column-rank ``Q``), every matrix of this family
    satisfies ``P @ Q = I``. That is, it is a generalized (left) inverse of ``Q``.
    The map is differentiable with respect to ``M_torch``.

    args:
        M_torch(torch.tensor) : free matrix, same shape as P_MP
        P_MP(torch.tensor) : reference inverse of Q (presumably Moore-Penrose)
        Q(torch.tensor) : lifting matrix
    return:
        P_projected(torch.tensor) : new dense tensor of the same shape as M_torch
    """
    identity = torch.eye(M_torch.shape[1], device=M_torch.device)
    return P_MP + M_torch @ (identity - Q @ P_MP)


def minimize_rsa_Q_g_sparse(P_init:sp.csr_matrix,L:sp.csr_matrix,R:sp.csr_matrix,
                            Q:sp.csr_matrix, P_MP:sp.csr_matrix,
                            lr:Optional[float] = 0.01,
                            n_iter:Optional[int] = 100,
                            name_optim:Optional[str] = "Adam",
                            momentum:Optional[float] = 0.9,
                            device:Optional[torch.device] = torch.device('cpu'),
                            keep_historic:Optional[bool] = False,
                            penalize_function:Optional[str] = "l1_norm",
                            lambda_sparse:Optional[float] = 0.01
                            ) -> tuple[torch.tensor, torch.tensor, list, list, list]:
    """
    Learn a sparse coarsening matrix constrained to ``P = P_MP + M (I_N - Q P_MP)``.

    The free matrix ``M`` (initialized from ``P_init``) is optimized. The
    minimized objective is ``RSA(P) + lambda_sparse * penalty(P)``, where
    ``RSA(P) = ||L^{1/2} (I - Q P) R_dot||_2``.

    args:
        P_init(sp.csr_matrix) : initial value of the free matrix M
        L(sp.csr_matrix) : Laplacian matrix of the graph
        R(sp.csr_matrix) : matrix with K columns; its columns are rescaled by the inverse
            square roots of the K smallest eigenvalues of L (near-zero ones map to 0)
        Q(sp.csr_matrix) : lifting matrix
        P_MP(sp.csr_matrix) : affine offset of the family, presumably the Moore-Penrose
            pseudo-inverse of Q
        lr(float) : learning rate
        n_iter(int) : number of iterations
        name_optim(str) : one of "Adam", "SGD", "RMSprop", "RADAM"
        momentum(float) : momentum, only used by SGD
        device(torch.device) : device on which the dense tensors are allocated
        keep_historic(bool) : if True, P is recorded at every iteration (see NOTE below)
        penalize_function(str) : sparsity penalty, one of
            "l1_norm" (entry-wise l1 norm of P),
            "l_2_1_norm" (sum of the column l2 norms of P),
            "l1_norm_support" (l1 norm of the entries of P where Q.T > 0)
        lambda_sparse(float) : weight of the sparsity penalty
    return:
        M_torch(torch.tensor) : final free matrix M
        P_torch_M_final(torch.tensor) : dense coarsening matrix built from the final M
        loss_combined_l(list) : combined loss at every iteration
        loss_rsa_l(list) : RSA term at every iteration
        loss_sparse_l(list) : (unweighted) penalty term at every iteration
    raises:
        ValueError : unknown ``name_optim`` or ``penalize_function``
    """
    # Validate the configuration before any allocation or eigendecomposition.
    if penalize_function not in ("l1_norm", "l_2_1_norm", "l1_norm_support"):
        raise ValueError("choose valid penalize_function between l1_norm, l_2_1_norm and l1_norm_support")
    # NOTE(review): this history is filled when keep_historic is True but is not
    # part of the return value; returning it would change the 5-tuple callers unpack.
    P_torch_M_historic = []
    M_torch = torch.tensor(P_init.toarray(), dtype=torch.float32, requires_grad=True, device=device)
    P_MP_torch = torch.tensor(P_MP.toarray(), dtype=torch.float32, device=device)
    if name_optim == "Adam":
        optimizer_m = optim.Adam([M_torch], lr=lr)
    elif name_optim == "SGD":
        optimizer_m = optim.SGD([M_torch], lr=lr, momentum=momentum)
    elif name_optim == "RMSprop":
        optimizer_m = optim.RMSprop([M_torch], lr=lr)
    elif name_optim == "RADAM":
        optimizer_m = optim.RAdam([M_torch], lr=lr)
    else:
        raise ValueError(f"Unknown name_optim {name_optim!r}; choose between Adam, SGD, RMSprop and RADAM")
    L_torch = torch.tensor(L.toarray(), dtype=torch.float32, device=device)
    loss_combined_l = []
    loss_rsa_l = []
    loss_sparse_l = []

    _, K = R.shape
    R_torch = torch.tensor(R.toarray(), dtype=torch.float32, device=device)
    eig_val, eig_vec = torch.linalg.eigh(L_torch)

    # Eigenvalues below 1e-4 are treated as exact zeros: they contribute 0 to
    # L^{1/2} and their inverse square root is set to 0 (pseudo-inverse behaviour).
    mask_zero_eig = eig_val < 1e-4
    eig_val[mask_zero_eig] = 0
    eig_val_sqrt = eig_val**0.5
    L_sqrt = eig_vec @ torch.diag(eig_val_sqrt) @ eig_vec.T
    eig_val[mask_zero_eig] = 1  # avoid division by zero before masking
    eig_val_inv_sqrt = eig_val**(-0.5)
    eig_val_inv_sqrt[mask_zero_eig] = 0
    R_dot_eig = R_torch @ torch.diag(eig_val_inv_sqrt[:K])
    Q_torch = torch.tensor(Q.toarray(), dtype=torch.float32, device=device)
    # Loop-invariant: the support mask only depends on Q.
    # NOTE(review): this keeps the entries ON the support of Q.T; confirm that
    # penalizing the on-support (rather than off-support) entries is intended.
    mask_support = Q_torch.T > 0

    for _ in tqdm(range(n_iter)):
        optimizer_m.zero_grad()
        P_torch_M = project_generalized_inverse(M_torch, P_MP_torch, Q_torch)
        loss_rsa = compute_rsa_differentiable(P_torch_M, Q_torch, L_sqrt, R_dot_eig)
        if penalize_function == "l1_norm":
            loss_sparse = torch.norm(P_torch_M, p=1)
        elif penalize_function == "l_2_1_norm":
            column_l2 = torch.norm(P_torch_M, p=2, dim=0)   # l2 norm per column
            loss_sparse = column_l2.sum()
        else:  # "l1_norm_support", validated above
            P_torch_M_masked = P_torch_M * mask_support
            loss_sparse = torch.norm(P_torch_M_masked, p=1)
        loss_combined = loss_rsa + lambda_sparse * loss_sparse

        loss_combined_l.append(loss_combined.item())
        loss_rsa_l.append(loss_rsa.item())
        loss_sparse_l.append(loss_sparse.item())
        loss_combined.backward()
        optimizer_m.step()
        if keep_historic:
            P_torch_M_historic.append(P_torch_M.clone().to_dense().detach().cpu().numpy())
    P_torch_M_final = project_generalized_inverse(M_torch, P_MP_torch, Q_torch)
    return M_torch, P_torch_M_final, loss_combined_l, loss_rsa_l, loss_sparse_l