import scipy.sparse as sp
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import scipy.sparse as sp
from torch_geometric.data import Data
from torch_geometric.utils.convert import from_scipy_sparse_matrix
from utils.data_utils import Graph, CoarsenedGraph
from typing import Optional
from torch import Tensor
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from typing import Optional, Tuple


class SGC(nn.Module):
    """Linear read-out used for Simplified Graph Convolution (SGC).

    The module only owns a single linear layer. ``forward`` applies that layer
    to the features it is given and never touches the graph structure, so the
    caller is presumably expected to pass features that were already
    propagated (see the ``x_precomputed`` argument name).

    Args:
        in_channels: size of each input feature vector.
        out_channels: size of each output vector.
        nlayer: propagation order, stored as ``self.K``; ``forward`` does not
            read it.
        bias: whether the linear layer has a bias term.
    """

    def __init__(self, in_channels:int, out_channels:int, nlayer:int, bias = False):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        # Order of propagation; kept as metadata only.
        self.K = nlayer
        self.bias = bias
        self.W = nn.Linear(in_features=in_channels, out_features=out_channels, bias=bias)

    def forward(self, x_precomputed: Tensor, edge_index: Tensor, edge_weight: Tensor) -> Tensor:
        """Apply the linear layer to ``x_precomputed``.

        ``edge_index`` and ``edge_weight`` are accepted so the call signature
        matches the other models in this file, but they are ignored.
        """
        logits = self.W(x_precomputed)
        return logits

class GCN_adapted(nn.Module):
    """Stack of ``GCNConv`` layers fed with an externally supplied propagation matrix.

    Every convolution is built with ``normalize=False`` and
    ``add_self_loops=False``: the edge weights handed to ``forward`` are used
    as given (presumably because the propagation matrix was normalised
    beforehand - confirm against the caller).

    Args:
        in_channels: size of the input features.
        hidden_channels: width of each hidden layer; must contain at least one
            entry.
        use_sigmoid: use a sigmoid after hidden layers when True, ReLU otherwise.
        dropout: dropout probability applied after each hidden activation.
        output_dim: size of the output of the last layer.
    """

    def __init__(self, in_channels:int, hidden_channels:list, use_sigmoid:bool, dropout:float, output_dim:int):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.use_sigmoid = use_sigmoid
        # Number of convolutions: one per hidden width plus the output layer.
        self.nlayer = len(hidden_channels) + 1
        self.dropout = dropout

        conv_options = {"normalize": False, "add_self_loops": False}
        stack = nn.ModuleList()
        # Layers are created input -> output so that seeded initialisation
        # consumes the random generator in the same order every time.
        stack.append(GCNConv(in_channels, hidden_channels[0], **conv_options))
        for width_in, width_out in zip(hidden_channels, hidden_channels[1:]):
            stack.append(GCNConv(width_in, width_out, **conv_options))
        stack.append(GCNConv(hidden_channels[-1], output_dim, **conv_options))
        self.gcn_layers = stack

    def forward(self, x,edge_index,edge_weight):
        """Run all hidden layers (conv, activation, dropout) then the output conv.

        The output of the last layer is returned without activation or dropout.
        """
        activation = F.sigmoid if self.use_sigmoid else F.relu
        *hidden_convs, output_conv = self.gcn_layers
        for conv in hidden_convs:
            x = activation(conv(x, edge_index, edge_weight))
            x = F.dropout(x, p=self.dropout, training=self.training)
        return output_conv(x, edge_index, edge_weight)
        

def get_propag_matrix(G: "Graph", matrix_propag_name: str) -> sp.csr_matrix:
    """
    Build a propagation matrix from the adjacency of a graph.

    Supported values of ``matrix_propag_name``:
        * ``"self_loop_adj"``: D_hat^-1/2 (A + I) D_hat^-1/2 with D_hat = D + I
        * ``"normalized_adj"``: D^-1/2 A D^-1/2
        * ``"combinatorial_tuned"``: I - L / num_nodes, where L is the
          combinatorial Laplacian returned by ``G.get_laplacian``

    Args:
        G (Graph): graph exposing ``csr_adj`` (and, for
            ``"combinatorial_tuned"``, ``get_laplacian`` and ``num_nodes``)
        matrix_propag_name (str): name of the propagation matrix
    Returns:
        sp.csr_matrix: propagation matrix
    Raises:
        ValueError: if ``matrix_propag_name`` is not one of the names above
    """
    adjacency = G.csr_adj

    if matrix_propag_name == "self_loop_adj":
        # ravel() (rather than squeeze()) keeps a 1-D vector even for a
        # single-node graph, which sp.diags requires.
        degree_hat = np.asarray(adjacency.sum(axis=1)).ravel() + 1
        inv_sqrt = sp.diags(1 / np.sqrt(degree_hat))
        return inv_sqrt @ (adjacency + sp.eye(adjacency.shape[0])) @ inv_sqrt

    if matrix_propag_name == "normalized_adj":
        degree = np.asarray(adjacency.sum(axis=1)).ravel()
        inv_sqrt = sp.diags(1 / np.sqrt(degree))
        return inv_sqrt @ adjacency @ inv_sqrt

    if matrix_propag_name == "combinatorial_tuned":
        combinatorial_L = G.get_laplacian(laplacian_name="combinatorial")
        return sp.eye(combinatorial_L.shape[0]) - (combinatorial_L / G.num_nodes)

    # Previously this case only printed a message and then crashed on an
    # unbound local; fail explicitly instead.
    raise ValueError(
        'matrix_propag_name {} not recognized, please choose between '
        'self_loop_adj, normalized_adj, combinatorial_tuned'.format(matrix_propag_name)
    )


def get_propagation_matrix_coarsen(G_c: CoarsenedGraph, S_original: sp.csr_matrix, name_coarsened_propag: str, name_original_propag:Optional[str] = "self_loop_adj") -> sp.csr_matrix:
    """
    Build the propagation matrix used on the coarsened graph.

    Supported values of ``name_coarsened_propag`` (``P`` and ``Q`` are read
    from ``G_c``):
        * ``"MP_orientation"``: P @ S_original @ Q
        * ``"symmetric_classic"``: Q.T @ S_original @ Q
        * ``"classic"``: ``get_propag_matrix`` applied directly to ``G_c``
          with ``name_original_propag``
        * ``"diffpool_inspired"``: P @ S_original @ P.T
        * ``"diag_cluster"``: symmetric normalisation of the coarse adjacency
          where the diagonal term of each coarse node is the count of entries
          of its column of ``Q`` above 1e-7 (entries at or below the threshold
          contribute their own value). Only allowed when
          ``name_original_propag`` is ``"self_loop_adj"``.

    Args:
        G_c (CoarsenedGraph): coarsened graph
        S_original (sp.csr_matrix): propagation matrix of the original graph
        name_coarsened_propag (str): name of the coarsened propagation matrix
        name_original_propag (str): name of the propagation matrix used on the
            original graph
    Returns:
        sp.csr_matrix: propagation matrix for the coarsened graph
    Raises:
        ValueError: unknown ``name_coarsened_propag``, or ``"diag_cluster"``
            requested with a propagation other than ``"self_loop_adj"``
    """
    P, Q = G_c.P, G_c.Q

    if name_coarsened_propag == "MP_orientation":
        return P @ S_original @ Q
    if name_coarsened_propag == "symmetric_classic":
        return Q.T @ S_original @ Q
    if name_coarsened_propag == "classic":
        return get_propag_matrix(G_c, matrix_propag_name=name_original_propag)
    if name_coarsened_propag == "diffpool_inspired":
        return P @ S_original @ P.T
    if name_coarsened_propag != "diag_cluster":
        raise ValueError('method {} not recognized, please choose between MP_orientation, symmetric_classic, classic, diffpool_inspired, diag_cluster'.format(name_coarsened_propag))

    # Only the "diag_cluster" case remains from here on.
    if name_original_propag != "self_loop_adj":
        raise ValueError('diag_cluster orientation is only available for self_loop_adj')

    coarse_adjacency = G_c.csr_adj
    coarse_degree = np.array(coarse_adjacency.sum(axis=1)).squeeze()
    dense_Q = Q.toarray()
    # Entries above the threshold count as 1; the others keep their value.
    membership = np.where(dense_Q > 1e-7, 1, dense_Q)
    cluster_sizes = membership.T.sum(axis=1)
    inv_sqrt_degree = sp.diags(1 / np.sqrt(coarse_degree + cluster_sizes))
    return inv_sqrt_degree @ (coarse_adjacency + sp.diags(cluster_sizes)) @ inv_sqrt_degree



def graph_with_propag_to_pyg(G:"Graph", propag_matrix:sp.csr_matrix) -> Data:
    """
    Convert a graph and its propagation matrix into a PyTorch Geometric ``Data``.

    The non-zero entries of ``propag_matrix`` become ``edge_index`` and
    ``edge_attr``; ``G.features`` becomes the dense float tensor ``x``. Labels
    and train/val/test masks are attached unless all four are ``None`` on
    ``G``, in which case a warning is printed and they are left out.

    Args:
        G (Graph): graph exposing ``features``, ``labels``, ``train_mask``,
            ``val_mask`` and ``test_mask``
        propag_matrix (sp.csr_matrix): propagation matrix providing the edges
    Returns:
        Data: PyTorch Geometric graph
    """
    edge_index, edge_weight = from_scipy_sparse_matrix(propag_matrix)

    unlabeled = G.labels is None and G.train_mask is None and G.val_mask is None and G.test_mask is None
    if unlabeled:
        print("Warning: no labels, no train/val/test mask, creating a graph without labels", flush=True)

    fields = {
        "x": torch.tensor(G.features.todense(), dtype=torch.float),
        "edge_index": edge_index,
        # edge_weight is already a tensor: copy it with clone().detach()
        # instead of torch.tensor(), which warns on tensor input.
        "edge_attr": edge_weight.clone().detach().to(torch.float),
    }
    if not unlabeled:
        fields.update(
            y=torch.tensor(G.labels, dtype=torch.long),
            train_mask=torch.tensor(G.train_mask, dtype=torch.bool),
            val_mask=torch.tensor(G.val_mask, dtype=torch.bool),
            test_mask=torch.tensor(G.test_mask, dtype=torch.bool),
        )
    return Data(**fields)


def compute_propagated_features(features:sp.csr_matrix, propag_matrix:sp.csr_matrix, nlayer:int) -> sp.csr_matrix:
    """
    Left-multiply ``features`` by ``propag_matrix`` ``nlayer`` times.

    The input is copied first, so with ``nlayer == 0`` the result is a copy of
    ``features`` rather than the same object.
    """
    current = features.copy()
    for _ in range(nlayer):
        current = propag_matrix @ current
    return current



def evaluate_SGC_coarsened( coarsened_graph_pyg:Data, lifting_matrix:torch.Tensor, precomputed_features_gc:torch.Tensor,
        original_graph:Graph, n_epochs:int, lr:float, wd:float, n_layer:int, 
                           device, mean_iter:Optional[int]=10, early_stopping_patience:Optional[int]=None) -> dict:
    """
    Train and evaluate SGC on the coarsened graph over several runs.

    A clone of ``coarsened_graph_pyg`` receives ``precomputed_features_gc`` as
    node features (the input graph is left untouched). A fresh ``SGC`` model
    is then trained ``mean_iter`` times with Adam and cross-entropy through
    ``train_coarsen_node_task``. The global torch seed is set to 42 once
    before the runs.

    Args:
        coarsened_graph_pyg: coarsened graph in PyTorch Geometric format.
        lifting_matrix: matrix mapping coarsened outputs to original nodes.
        precomputed_features_gc: features used as ``x`` on the coarsened graph.
        original_graph: original graph providing features, labels and masks.
        n_epochs: maximum number of epochs per run.
        lr: Adam learning rate.
        wd: Adam weight decay.
        n_layer: propagation order forwarded to ``SGC``.
        device: device given to the training routine (string or torch.device).
        mean_iter: number of independent runs.
        early_stopping_patience: forwarded to ``train_coarsen_node_task``.
    Returns:
        dict: ``{"SGC": [means, stds]}`` where both lists follow the order
        [test accuracy at best validation loss, best test accuracy,
        best validation loss, epoch of best validation loss,
        epoch of best test accuracy].
    """
    graph_with_features = coarsened_graph_pyg.clone()
    graph_with_features.x = precomputed_features_gc
    dim_features = original_graph.features.shape[1]
    output_dim = original_graph.labels.max().item() + 1

    # `device == 'cuda'` is False for torch.device objects and for indexed
    # names such as 'cuda:0'; compare on the textual form instead.
    on_cuda = str(device).startswith('cuda')

    results_SGC = []
    torch.manual_seed(42)
    for _ in range(mean_iter):
        if on_cuda:
            torch.cuda.empty_cache()
        SGC_network = SGC(dim_features, output_dim, n_layer)
        optimizer = optim.Adam(SGC_network.parameters(), lr=lr, weight_decay=wd)
        loss_fn = nn.CrossEntropyLoss()
        run_summary = train_coarsen_node_task(
            model=SGC_network, optimizer=optimizer, loss_fn=loss_fn, device=device, n_epochs=n_epochs,
            original_graph=original_graph, coarsened_graph_pyg=graph_with_features,
            lifting_matrix=lifting_matrix, early_stopping_patience=early_stopping_patience)
        # Keep the five scalar metrics; the trailing per-epoch loss lists are dropped.
        results_SGC.append(list(run_summary[:5]))

    return {"SGC": [np.mean(results_SGC, axis=0).tolist(), np.std(results_SGC, axis=0).tolist()]}



def train_coarsen_node_task(model:nn.Module, optimizer:optim.Optimizer, loss_fn:nn.Module, n_epochs:int, original_graph:"Graph", 
                                coarsened_graph_pyg: "Data",  lifting_matrix: torch.Tensor,  device, early_stopping_patience:Optional[int]=None) -> Tuple :
    """
    Train a model on the coarsened graph against the labels of the original graph.

    At every epoch the model runs on the coarsened graph, its output is lifted
    to the original nodes with ``lifting_matrix @ output`` and the loss is
    computed on the training nodes of ``original_graph``. Validation and test
    losses/accuracies are then measured in eval mode without gradient tracking.

    Args:
        model: network called as ``model(x, edge_index, edge_attr)``.
        optimizer: optimizer already bound to the model parameters.
        loss_fn: loss applied to (lifted outputs, labels).
        n_epochs: maximum number of epochs; must be positive.
        original_graph: provides ``labels``, ``train_mask``, ``val_mask`` and
            ``test_mask``.
        coarsened_graph_pyg: provides ``x``, ``edge_index`` and ``edge_attr``.
        lifting_matrix: maps coarsened outputs to original nodes.
        device: device everything is moved to.
        early_stopping_patience: if not None, training stops once the
            validation loss has failed to improve this many epochs in a row.
    Returns:
        Tuple: (test accuracy at the epoch of best validation loss,
        best test accuracy, best validation loss, epoch of best validation
        loss, epoch of best test accuracy, validation losses, test losses,
        training losses).
    Raises:
        ValueError: if ``n_epochs`` is not positive.
    """
    if n_epochs <= 0:
        raise ValueError("n_epochs must be a positive integer, got {}".format(n_epochs))

    model = model.to(device)
    labels_original = torch.tensor(original_graph.labels, dtype=torch.long).to(device)
    train_mask = torch.tensor(original_graph.train_mask, dtype=torch.bool).to(device)
    val_mask = torch.tensor(original_graph.val_mask, dtype=torch.bool).to(device)
    test_mask = torch.tensor(original_graph.test_mask, dtype=torch.bool).to(device)
    # The masked labels never change, so select them once.
    train_labels = labels_original[train_mask]
    val_labels = labels_original[val_mask]
    test_labels = labels_original[test_mask]
    coarsened_graph_pyg = coarsened_graph_pyg.to(device)
    lifting_matrix = lifting_matrix.to(device)

    train_loss_l = []
    val_loss_l = []
    test_loss_l = []
    test_acc_l = []
    val_acc_l = []  # collected per epoch but not part of the return value

    lowest_val_loss = np.inf
    # Initialised up front: a first validation loss that is not below +inf
    # (e.g. NaN) would otherwise decrement an undefined counter.
    patience = early_stopping_patience

    for current_epoch in range(n_epochs):
        model.train()
        optimizer.zero_grad()
        output = model(coarsened_graph_pyg.x, coarsened_graph_pyg.edge_index, coarsened_graph_pyg.edge_attr)
        output_original = lifting_matrix @ output
        train_loss = loss_fn(output_original[train_mask], train_labels)
        train_loss.backward()
        optimizer.step()

        model.eval()
        # No gradients are needed for evaluation; this also avoids keeping
        # the autograd graph alive through the stored best loss.
        with torch.no_grad():
            output_val = model(coarsened_graph_pyg.x, coarsened_graph_pyg.edge_index, coarsened_graph_pyg.edge_attr)
            output_original_val = lifting_matrix @ output_val
            val_logits = output_original_val[val_mask]
            test_logits = output_original_val[test_mask]
            val_loss = loss_fn(val_logits, val_labels)
            test_loss = loss_fn(test_logits, test_labels)
            predicted_classes_val = F.softmax(val_logits, dim=1).argmax(dim=1)
            predicted_classes_test = F.softmax(test_logits, dim=1).argmax(dim=1)
        val_acc = torch.sum(predicted_classes_val == val_labels).item() / len(val_labels)
        test_acc = torch.sum(predicted_classes_test == test_labels).item() / len(test_labels)

        val_loss_l.append(val_loss.detach().cpu().numpy())
        test_loss_l.append(test_loss.detach().cpu().numpy())
        train_loss_l.append(train_loss.detach().cpu().numpy())
        test_acc_l.append(test_acc)
        val_acc_l.append(val_acc)

        if early_stopping_patience is not None:
            current_val_loss = val_loss.item()
            if current_val_loss < lowest_val_loss:
                lowest_val_loss = current_val_loss
                patience = early_stopping_patience
            else:
                patience -= 1
            if patience == 0:
                break

    best_acc_test = max(test_acc_l)
    best_loss_val = min(val_loss_l)
    best_epoch_val = val_loss_l.index(best_loss_val)
    best_epoch_test = test_acc_l.index(best_acc_test)
    test_acc_according_to_val = test_acc_l[best_epoch_val]
    return test_acc_according_to_val, best_acc_test, best_loss_val, best_epoch_val, best_epoch_test, val_loss_l, test_loss_l, train_loss_l


def evaluate_GCN_coarsened(coarsened_graph_pyg:Data,
                                lifting_matrix:torch.Tensor,
                                original_graph:Graph, n_epochs:int, lr:float, wd:float, hidden_channels:list, 
                           device:Optional[torch.device]= torch.device('cpu'), mean_iter:Optional[int]=10, use_sigmoid:Optional[bool]=False, dropout:Optional[float]=0.5, early_stopping_patience:Optional[int]=None) -> dict:
    """
    Train and evaluate a GCN on the coarsened graph over several runs.

    A fresh ``GCN_adapted`` model is trained ``mean_iter`` times with Adam and
    cross-entropy through ``train_coarsen_node_task``. The global torch seed
    is set to 42 once before the runs.

    Args:
        coarsened_graph_pyg: coarsened graph in PyTorch Geometric format.
        lifting_matrix: matrix mapping coarsened outputs to original nodes.
        original_graph: original graph providing features, labels and masks.
        n_epochs: maximum number of epochs per run.
        lr: Adam learning rate.
        wd: Adam weight decay.
        hidden_channels: hidden layer widths forwarded to ``GCN_adapted``.
        device: device given to the training routine.
        mean_iter: number of independent runs.
        use_sigmoid: forwarded to ``GCN_adapted``.
        dropout: forwarded to ``GCN_adapted``.
        early_stopping_patience: forwarded to ``train_coarsen_node_task``.
    Returns:
        dict: ``{"GCN_conv": [means, stds]}`` where both lists follow the
        order [test accuracy at best validation loss, best test accuracy,
        best validation loss, epoch of best validation loss,
        epoch of best test accuracy].
    """
    dim_features = original_graph.features.shape[1]
    output_dim = original_graph.labels.max().item() + 1

    # A torch.device never compares equal to the string 'cuda', and indexed
    # names such as 'cuda:0' would be missed too; test the textual form.
    on_cuda = str(device).startswith('cuda')

    results_GCN_conv = []
    torch.manual_seed(42)
    for _ in range(mean_iter):
        if on_cuda:
            torch.cuda.empty_cache()
        GCN_conv_network = GCN_adapted(dim_features, hidden_channels, output_dim =output_dim, use_sigmoid=use_sigmoid, dropout=dropout)
        optimizer = optim.Adam(GCN_conv_network.parameters(), lr=lr, weight_decay=wd)
        loss_fn = nn.CrossEntropyLoss()
        run_summary = train_coarsen_node_task(
            model=GCN_conv_network, optimizer=optimizer, loss_fn=loss_fn, device=device, n_epochs=n_epochs,
            original_graph=original_graph, coarsened_graph_pyg=coarsened_graph_pyg,
            lifting_matrix=lifting_matrix, early_stopping_patience=early_stopping_patience)
        # Keep the five scalar metrics; the trailing per-epoch loss lists are dropped.
        results_GCN_conv.append(list(run_summary[:5]))

    return {"GCN_conv": [np.mean(results_GCN_conv, axis=0).tolist(), np.std(results_GCN_conv, axis=0).tolist()]}