import numpy as np
import scipy.sparse as sp
import torch

from tqdm import tqdm
from utils.coarsen_utils import compute_rsa_exact, coarsen_algo_inspired_loukas
from utils.optim_utils import minimize_rsa_support_l_normalized, minimize_rsa_support, minimize_rsa_Q_g_sparse, minimize_rsa_Q_g, torch_P_according_mu
from utils.training_utils import get_propag_matrix, get_propagation_matrix_coarsen, graph_with_propag_to_pyg,compute_propagated_features, evaluate_GCN_coarsened, evaluate_SGC_coarsened

def _select_device():
    """Return the CUDA device when available (CPU otherwise) and print it."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    return device


def _node_degrees(graph):
    """Return the row sums of ``graph.csr_adj`` as a squeezed NumPy array."""
    return np.array(graph.csr_adj.sum(axis=1)).squeeze()


def _init_mu_from_P(Q_lift, P, device):
    """Build the trainable weight vector ``mu`` from the entries of ``P``.

    For every nonzero ``(row, col)`` of ``Q_lift.T`` the value ``P[row, col]``
    is stored at position ``col`` of a vector of length ``Q_lift.nnz``.

    The gather and scatter are done in one vectorized step on the host and the
    result is moved to ``device`` once, instead of issuing one element write
    per nonzero on the device tensor.

    Args:
        Q_lift: Sparse lifting matrix (must expose ``.T``, ``.nonzero()`` and
            ``.nnz``).
        P: Matrix indexable as ``P[rows, cols]`` with the index arrays of
            ``Q_lift.T``.
        device: Torch device on which the returned tensor lives.

    Returns:
        A leaf tensor of length ``Q_lift.nnz`` with ``requires_grad=True``.
    """
    # NOTE(review): indexing mu by column assumes every column of Q_lift.T
    # holds exactly one nonzero, so columns enumerate 0..nnz-1 - TODO confirm.
    row_indices, col_indices = (Q_lift.T).nonzero()
    values = P[row_indices, col_indices]
    if sp.issparse(values):
        # Some sparse formats answer fancy indexing with a sparse result.
        values = values.toarray()
    values = np.asarray(values).ravel()

    buffer = np.zeros(Q_lift.nnz, dtype=np.float64)
    # NumPy keeps the last value for repeated indices, like a sequential loop.
    buffer[col_indices] = values
    mu = torch.tensor(buffer, dtype=torch.get_default_dtype(), device=device)
    return mu.requires_grad_(True)


def _torch_P_to_csr(P_torch, prune_below=None):
    """Convert a torch matrix to a SciPy CSR matrix, optionally pruning it.

    Args:
        P_torch: Torch tensor supporting ``to_dense()``.
        prune_below: When not ``None``, stored entries whose absolute value is
            not strictly greater than this threshold are zeroed and removed.

    Returns:
        A ``scipy.sparse.csr_matrix``.
    """
    P_csr = sp.csr_matrix(P_torch.to_dense().detach().cpu().numpy())
    if prune_below is not None:
        P_csr.data *= (np.abs(P_csr.data) > prune_below)
        P_csr.eliminate_zeros()
    return P_csr


def _minimize_support_normalized(mu_init, L_original_graph, R_graph, P_loukas,
                                 Q_lift, coarsened_graph, degree_original,
                                 device):
    """Run ``minimize_rsa_support_l_normalized`` with the fixed settings.

    Returns the unmodified output tuple of the optimizer.
    """
    degree_coarsened = _node_degrees(coarsened_graph)
    return minimize_rsa_support_l_normalized(
        mu_init, L_original_graph, R_graph, P_loukas, Q_lift,
        lr=0.05, n_iter=200, keep_historic=True, name_optim="SGD",
        project_generalized_inverse=True,
        degree_coarsened=degree_coarsened, degree_original=degree_original,
        device=device)


def _optimize_P_g(P_mp, L_original_graph, R_graph, Q_lift, device):
    """Run ``minimize_rsa_Q_g`` from ``P_mp`` and return the pruned CSR result."""
    _, _, P_final, _, _ = minimize_rsa_Q_g(
        P_mp, L_original_graph, R_graph, Q_lift, P_mp,
        lr=0.01, n_iter=200, keep_historic=True, name_optim="SGD",
        device=device)
    return _torch_P_to_csr(P_final, prune_below=0.001)


def _optimize_P_g_sparse(P_init, P_mp, L_original_graph, R_graph, Q_lift,
                         device):
    """Run the l1-penalized ``minimize_rsa_Q_g_sparse`` from ``P_init``.

    Returns the optimized matrix as a pruned CSR matrix.
    """
    _, P_final, _, _, _ = minimize_rsa_Q_g_sparse(
        P_init, L_original_graph, R_graph, Q_lift, P_mp,
        lr=0.01, n_iter=200, keep_historic=True, name_optim="SGD",
        penalize_function="l1_norm", lambda_sparse=0.01, device=device)
    return _torch_P_to_csr(P_final, prune_below=0.001)


def _exact_rsa_with_P(coarsened_graph, original_graph, P, laplacian_kind):
    """Install ``P`` on ``coarsened_graph`` and return ``compute_rsa_exact``."""
    coarsened_graph.P = P
    return compute_rsa_exact(coarsened_graph, original_graph,
                             laplacian_name=laplacian_kind)


def rsa_minimization(original_graph, config, laplacian_kind, r_list):
    """Compare the exact RSA of several reduction matrices for each ratio.

    For every ``r`` in ``r_list`` the graph is coarsened with
    ``coarsen_algo_inspired_loukas`` and the RSA is evaluated for the
    ``P_loukas``, ``P_rao`` and ``P_mp`` matrices of the coarsened graph, for a
    support-constrained optimization (final value of its loss history), and
    for the two ``minimize_rsa_Q_g*`` optimizations (exact RSA after pruning).

    Side effects: ``config.laplacian_norm``, ``config.laplacian_preserved``
    and ``config.r`` are overwritten.

    Args:
        original_graph: Graph exposing ``get_laplacian`` and ``csr_adj``.
        config: Coarsening configuration object (mutated, see above).
        laplacian_kind: ``"normalized_self_loop"`` or ``"combinatorial"``.
        r_list: Iterable of coarsening ratios.

    Returns:
        Six lists with one entry per ``r``, in this order: Loukas, MP, Opt
        (``P_rao``), support, optim_g, optim_g_sparse.

    Raises:
        ValueError: If ``config.laplacian_preserved`` is neither
            ``"normalized_self_loop"`` nor ``"combinatorial"``.
    """
    device = _select_device()

    config.laplacian_norm = laplacian_kind
    config.laplacian_preserved = laplacian_kind

    L_original_graph = original_graph.get_laplacian(laplacian_kind)
    degree_original = _node_degrees(original_graph)

    rsa_loukas_list = []
    rsa_mp_list = []
    rsa_opt_list = []
    rsa_support_list = []
    rsa_optim_g_list = []
    rsa_optim_g_sparse_list = []

    for r in tqdm(r_list):
        config.r = r
        print(f"Running for r = {r}")

        coarsened_graph = coarsen_algo_inspired_loukas(original_graph, config)
        Q_lift = coarsened_graph.Q
        R_graph = coarsened_graph.R

        # Closed-form candidates shipped with the coarsened graph.
        P_loukas = coarsened_graph.P_loukas
        rsa_loukas_list.append(_exact_rsa_with_P(
            coarsened_graph, original_graph, P_loukas, laplacian_kind))

        P_opt = coarsened_graph.P_rao
        rsa_opt_list.append(_exact_rsa_with_P(
            coarsened_graph, original_graph, P_opt, laplacian_kind))

        P_mp = coarsened_graph.P_mp
        rsa_mp_list.append(_exact_rsa_with_P(
            coarsened_graph, original_graph, P_mp, laplacian_kind))

        # Optimization restricted to the support of Q_lift.T.
        mu_mp_init = _init_mu_from_P(Q_lift, P_mp, device)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if config.laplacian_preserved == "normalized_self_loop":
            _, _, _, _, loss_historic_mp, _ = _minimize_support_normalized(
                mu_mp_init, L_original_graph, R_graph, P_loukas, Q_lift,
                coarsened_graph, degree_original, device)
        elif config.laplacian_preserved == "combinatorial":
            _, _, _, _, loss_historic_mp, _ = minimize_rsa_support(
                mu_mp_init, L_original_graph, R_graph, P_loukas, Q_lift,
                lr=0.05, n_iter=200, keep_historic=True, name_optim="SGD",
                device=device)
        else:
            raise ValueError("Laplacian type not supported")
        rsa_support_list.append(loss_historic_mp[-1])

        # Unconstrained optimization started from P_mp.
        P_optim_g = _optimize_P_g(
            P_mp, L_original_graph, R_graph, Q_lift, device)
        rsa_optim_g_list.append(_exact_rsa_with_P(
            coarsened_graph, original_graph, P_optim_g, laplacian_kind))

        # l1-penalized optimization started from P_rao.
        P_optim_g_sparse = _optimize_P_g_sparse(
            P_opt, P_mp, L_original_graph, R_graph, Q_lift, device)
        rsa_optim_g_sparse_list.append(_exact_rsa_with_P(
            coarsened_graph, original_graph, P_optim_g_sparse, laplacian_kind))

    return (rsa_loukas_list, rsa_mp_list, rsa_opt_list, rsa_support_list,
            rsa_optim_g_list, rsa_optim_g_sparse_list)


def training_exp(original_graph, config_coarsening, r_list):
    """Train SGC and GCN on coarsened graphs for several reduction matrices.

    For every ``r`` in ``r_list`` the graph is coarsened, six reduction
    matrices are built (support-optimized, ``P_mp``, ``P_rao``, ``P_loukas``,
    and the two ``minimize_rsa_Q_g*`` results), and for each of them the
    coarsened features and propagation matrix are recomputed before
    ``evaluate_SGC_coarsened`` and ``evaluate_GCN_coarsened`` are run.

    Side effects: ``config_coarsening.r`` is overwritten, and the ``P`` and
    features of each coarsened graph are updated for every matrix.

    Args:
        original_graph: Graph exposing ``get_laplacian``, ``csr_adj`` and
            ``features``.
        config_coarsening: Coarsening configuration object (``r`` is mutated).
        r_list: Iterable of coarsening ratios.

    Returns:
        A pair ``(results_SGC, results_gcn)``. Each maps a matrix name
        (``"Loukas"``, ``"MP"``, ``"Opt"``, ``"support"``, ``"optim_g"``,
        ``"optim_g_sparse"``) to ``{r: {"acc": ..., "std": ...}}``.
    """
    device = _select_device()

    L_original_graph = original_graph.get_laplacian("normalized_self_loop")
    degree_original = _node_degrees(original_graph)
    S_original = get_propag_matrix(original_graph,
                                   matrix_propag_name='self_loop_adj')

    result_names = ("Loukas", "MP", "Opt", "support", "optim_g",
                    "optim_g_sparse")
    results_SGC = {name: {} for name in result_names}
    results_gcn = {name: {} for name in result_names}

    for r in tqdm(r_list):
        config_coarsening.r = r
        print(f"Running for r = {r}")

        coarsened_graph = coarsen_algo_inspired_loukas(original_graph,
                                                       config_coarsening)
        Q_lift = coarsened_graph.Q
        R_graph = coarsened_graph.R

        P_loukas = coarsened_graph.P_loukas
        P_opt = coarsened_graph.P_rao
        P_mp = coarsened_graph.P_mp

        # Support-constrained optimization; here the normalized variant is
        # always used, independently of the configuration.
        mu_mp_init = _init_mu_from_P(Q_lift, P_mp, device)
        mu_mp_normalized, _, _, _, _, _ = _minimize_support_normalized(
            mu_mp_init, L_original_graph, R_graph, P_loukas, Q_lift,
            coarsened_graph, degree_original, device)
        P_support = _torch_P_to_csr(
            torch_P_according_mu(Q_lift.T, mu_mp_normalized))

        P_optim_g = _optimize_P_g(
            P_mp, L_original_graph, R_graph, Q_lift, device)
        P_optim_g_sparse = _optimize_P_g_sparse(
            P_opt, P_mp, L_original_graph, R_graph, Q_lift, device)

        lifting_matrix = coarsened_graph.get_lifting_torch()
        # Insertion order is the training order; keep it stable.
        dict_P_name = {"support": P_support, "MP": P_mp, "Opt": P_opt,
                       "Loukas": P_loukas, "optim_g": P_optim_g,
                       "optim_g_sparse": P_optim_g_sparse}
        print("starting training for r = ", r)
        for name_P, P_chosen in tqdm(dict_P_name.items()):
            print("Training for P = ", name_P)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            coarsened_graph.P = P_chosen
            coarsened_graph.recompute_features(
                original_features=original_graph.features, P=P_chosen)
            S_coarsen = get_propagation_matrix_coarsen(
                coarsened_graph, S_original,
                name_coarsened_propag="MP_orientation",
                name_original_propag='self_loop_adj')
            coarsened_graph_pyg = graph_with_propag_to_pyg(coarsened_graph,
                                                           S_coarsen)
            precomputed_features_gc_array = compute_propagated_features(
                coarsened_graph.features, S_coarsen, nlayer=2).toarray()
            precomputed_features_gc_torch = torch.tensor(
                precomputed_features_gc_array, dtype=torch.float)
            output_sgc = evaluate_SGC_coarsened(
                coarsened_graph_pyg, lifting_matrix,
                precomputed_features_gc_torch, original_graph,
                n_epochs=200, lr=0.1, wd=0.001, n_layer=2,
                device=device, mean_iter=10, early_stopping_patience=1000)
            output_gcn = evaluate_GCN_coarsened(
                coarsened_graph_pyg, lifting_matrix, original_graph,
                n_epochs=200, lr=0.01, wd=0.001, hidden_channels=[256, 128],
                device=device, mean_iter=10, use_sigmoid=False, dropout=0.5,
                early_stopping_patience=100)

            sgc_accuracy = output_sgc['SGC'][0][0]
            sgc_std = output_sgc['SGC'][1][0]
            gcn_accuracy = output_gcn['GCN_conv'][0][0]
            gcn_std = output_gcn['GCN_conv'][1][0]
            results_SGC[name_P][r] = {"acc": sgc_accuracy, "std": sgc_std}
            results_gcn[name_P][r] = {"acc": gcn_accuracy, "std": gcn_std}

    return results_SGC, results_gcn

            






