import os
import torch
import scipy.sparse as sp
import numpy as np
import json 
import copy 
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from typing import Optional, Any
from torch_geometric.utils import to_scipy_sparse_matrix
from scipy.sparse.csgraph import connected_components
from matplotlib.cm import ScalarMappable
from torch_geometric.data import Data
from torch_geometric.utils.convert import from_scipy_sparse_matrix
from torch_geometric.datasets import planetoid
from torch_geometric.transforms import LargestConnectedComponents


    

def load_planetoid(graph_name,path_data, split, largest_connected_component = True): 
    """Load a Planetoid dataset (via torch_geometric) and wrap it in a ``Graph``.

    Args:
        graph_name: dataset name passed to ``planetoid.Planetoid``.
        path_data: root directory; the dataset lives under ``path_data + '/planetoid'``.
        split: split name passed to ``planetoid.Planetoid``; a non-'public' split
            is announced on stdout.
        largest_connected_component: if True, keep only the largest connected
            component (``LargestConnectedComponents`` transform).

    Returns:
        Graph with a CSR adjacency (edge weights default to 1), CSR features,
        labels and train/val/test masks as numpy arrays.
    """
    dataset_root = path_data + '/planetoid'
    if split != 'public':
        print("We don't use the public split")
        print("We use the split:", split)
    data = planetoid.Planetoid(root = dataset_root, name = graph_name, split = split)[0]

    if largest_connected_component:
        print("We keep the largest connected component")
        print("num nodes before largest connected component:", data.num_nodes)
        data = LargestConnectedComponents()(data)
        print("num nodes after largest connected component:", data.num_nodes)

    # Planetoid graphs are unweighted: materialise unit weights so the adjacency has explicit values.
    if data.edge_attr is None:
        print("edge_attr is None, we add 1 to all edges")
        data.edge_attr = torch.ones(data.edge_index.shape[1], 1)

    adjacency = to_scipy_sparse_matrix(data.edge_index, edge_attr = data.edge_attr, num_nodes = data.num_nodes).tocsr()
    features = sp.csr_matrix(data.x.numpy())
    masks = {mask_name: getattr(data, mask_name).numpy() for mask_name in ('train_mask', 'val_mask', 'test_mask')}
    return Graph(csr_adj = adjacency, features = features, labels = data.y.numpy(), **masks)



def load_graph(graph_name, split:Optional[str] = 'public', keep_largest_connected_component:Optional[bool] = True, path:Optional[str] = 'data/original_datasets'):
    """Load a graph by name.

    Args:
        graph_name: 'PubMed', 'CiteSeer' or 'Cora' (Planetoid datasets), or
            'random_geometric' (1000 nodes, threshold 0.06, seed 42, forced connected).
        split: Planetoid split, forwarded to ``load_planetoid``.
        keep_largest_connected_component: forwarded to ``load_planetoid``.
        path: dataset root directory, forwarded to ``load_planetoid``.

    Returns:
        The loaded graph.

    Raises:
        ValueError: if ``graph_name`` is not a supported name.
    """
    planetoid_names = ('PubMed', 'CiteSeer', 'Cora')
    if graph_name in planetoid_names:
        return load_planetoid(graph_name, path, split = split, largest_connected_component = keep_largest_connected_component)
    if graph_name == "random_geometric":
        print("Random geometric graph created seed=42", flush=True)
        return create_geometric_graph(n_nodes=1000, threshold_link=0.06, force_connected=True, seed=42)
    print('Graph name not recognized : ', graph_name)
    raise ValueError('Please choose between PubMed, CiteSeer, Cora and random_geometric')
  

def _unit_symmetric_csr(rows, cols, n_nodes):
    """Build an (n_nodes, n_nodes) CSR matrix holding 1.0 at (r, c) and (c, r) for every pair."""
    rows = np.asarray(rows, dtype=np.intp)
    cols = np.asarray(cols, dtype=np.intp)
    adj = sp.csr_matrix(
        (np.ones(2 * rows.size), (np.concatenate([rows, cols]), np.concatenate([cols, rows]))),
        shape=(n_nodes, n_nodes),
    )
    # A pair listed twice would be summed on conversion; every stored edge weight must stay 1.
    adj.sum_duplicates()
    adj.data[:] = 1.0
    return adj


def create_geometric_graph(n_nodes, threshold_link, seed = 42, device = 'cpu', force_connected = True):
    """
    Create a random geometric graph.

    ``n_nodes`` points are drawn uniformly in [0,1] x [0,1] and two distinct
    nodes are linked (weight 1, symmetric) when their Euclidean distance is
    strictly less than ``threshold_link``.

    Args:
        n_nodes: number of nodes
        threshold_link: distance threshold (strict) under which two nodes are linked
        seed: passed to ``np.random.seed`` -- note this reseeds NumPy's global RNG
        device: unused; kept for backward compatibility
        force_connected: if True, each node outside the largest connected
            component is linked to its nearest node inside that component (the
            component grows as nodes are attached), so the result is connected
    Returns:
        Graph whose ``csr_adj`` is the adjacency and whose ``pos`` is the
        (n_nodes, 2) coordinate array
    """
    from scipy.spatial import cKDTree

    # Global seeding is kept on purpose: callers may rely on the global RNG state afterwards.
    np.random.seed(seed)
    coord = np.random.rand(n_nodes, 2)

    # Candidate pairs (i < j) from a KD-tree instead of an O(n^2) Python loop. The search
    # radius is inflated slightly and the exact strict test below decides, so rounding
    # inside the tree cannot drop a pair that the original `norm < threshold` test keeps.
    if n_nodes > 1:
        radius = max(float(threshold_link), 0.0) * (1.0 + 1e-9)
        candidates = cKDTree(coord).query_pairs(r=radius, output_type='ndarray')
    else:
        candidates = np.empty((0, 2), dtype=np.intp)
    candidates = np.asarray(candidates, dtype=np.intp).reshape(-1, 2)
    distances = np.linalg.norm(coord[candidates[:, 0]] - coord[candidates[:, 1]], axis=1)
    pairs = candidates[distances < threshold_link]
    adj_csr = _unit_symmetric_csr(pairs[:, 0], pairs[:, 1], n_nodes)

    if force_connected:
        n_components, labels = connected_components(csgraph = adj_csr)
        if n_components > 1:
            biggest_cc = np.argmax(np.bincount(labels))
            in_biggest = labels == biggest_cc
            nodes_not_in_biggest_cc = np.flatnonzero(~in_biggest)
            forced_rows, forced_cols = [], []
            for node in nodes_not_in_biggest_cc:
                members = np.flatnonzero(in_biggest)
                # argmin indexes into `members`; map it back to a global node id
                # (the previous code used the subset position directly -> wrong node).
                closest_element = members[np.argmin(np.linalg.norm(coord[node] - coord[members], axis = 1))]
                forced_rows.append(node)
                forced_cols.append(closest_element)
                # the attached node now belongs to the biggest component
                in_biggest[node] = True
            adj_csr = _unit_symmetric_csr(
                np.concatenate([pairs[:, 0], np.asarray(forced_rows, dtype=np.intp)]),
                np.concatenate([pairs[:, 1], np.asarray(forced_cols, dtype=np.intp)]),
                n_nodes,
            )
            nodes_forced_connected = len(nodes_not_in_biggest_cc)
        else:
            nodes_forced_connected = 0
        print(f"number of nodes forced connected {nodes_forced_connected}")

    my_type_graph = Graph(csr_adj = adj_csr, pos=coord)
    return my_type_graph



class Config:
    """Attribute-style container for configuration values.

    ``Config(lr=0.1).lr`` returns 0.1; reading any other ordinary attribute
    returns None instead of raising. Dunder names still raise AttributeError,
    so protocols probed via attribute lookup (``copy``/``pickle`` look up
    ``__setstate__``, ``__getnewargs_ex__``, ...) see them as absent instead of
    trying to call None.
    """

    def __init__(self, **entries):
        self.__dict__.update(entries)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup has already failed.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return self.__dict__.get(name)

    def __deepcopy__(self, memo):
        # Bypass __init__ so subclasses with required constructor arguments can be copied.
        new_copy = type(self).__new__(type(self))
        # Register before copying values so self-references resolve to the new copy.
        memo[id(self)] = new_copy
        for key, value in self.__dict__.items():
            try:
                new_copy.__dict__[key] = copy.deepcopy(value, memo)
            except Exception:
                # Best effort: values that cannot be deep-copied are shared, not dropped.
                print(f"Warning: could not deepcopy {key} of type {type(value)}")
                new_copy.__dict__[key] = value
        return new_copy





class Graph():
    def __init__(self, csr_adj:sp.csr_matrix, features:Optional[sp.csr_matrix]=None, train_mask:Optional[np.ndarray]=None, val_mask:Optional[np.ndarray]=None, 
                 test_mask:Optional[np.ndarray]=None, labels:Optional[np.ndarray]=None, pos:Optional[Any]=None, laplacian_combinatorial:Optional[sp.csr_matrix]=None, normalized_laplacian:Optional[sp.csr_matrix]=None, normalized_self_loop_laplacian:Optional[sp.csr_matrix]=None) -> None:
        self.csr_adj = csr_adj
        self.features = features
        self.train_mask = train_mask
        self.val_mask = val_mask
        self.test_mask = test_mask
        self.labels = labels
        self.pos = pos
        self.laplacian_combinatorial = laplacian_combinatorial
        self.normalized_laplacian = normalized_laplacian
        self.normalized_self_loop_laplacian = normalized_self_loop_laplacian
        if features is None:
            print("Warning: no features, no downstream task possible")
    
    def show_info(self):
        print("Number of nodes:", self.csr_adj.shape[0])
        print("Number of edges:", self.csr_adj.nnz)
        if self.features is not None:
            print("Number of features:", self.features.shape[1])
        if self.labels is not None:
            print("Number of classes:", len(np.unique(self.labels)))
        if self.train_mask is not None:
            print("Number of training nodes:", np.sum(self.train_mask))
        if self.val_mask is not None:
            print("Number of validation nodes:", np.sum(self.val_mask))
        if self.test_mask is not None:
            print("Number of test nodes:", np.sum(self.test_mask))

    def to_pyg(self):
        edges_indices, edges_weights = from_scipy_sparse_matrix(self.csr_adj)
        G_torch = Data(x=torch.tensor(self.features.todense(), dtype=torch.float), edge_index=edges_indices, edge_attr=torch.tensor(edges_weights, dtype=torch.float), y=torch.tensor(self.labels, dtype=torch.long), train_mask=torch.tensor(self.train_mask, dtype=torch.bool), val_mask=torch.tensor(self.val_mask, dtype=torch.bool), test_mask=torch.tensor(self.test_mask, dtype=torch.bool))
        return G_torch

    def plot(self, title = None, seed=42, node_size=10, width=0.6,  print_stats = False):
        """
        Plot a graph.
        Args:
            G: graph, of type Graph
            title: title of the plot
            seed: random seed
            node_size: size of the nodes
            width: width of the edges
            print_stats: print the distribution of the edge weights
        Returns:
            fig: figure of the plot
        """
        fig = plt.figure(figsize=(4, 4))
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        if title is not None:
            ax.set_title(title)
        G_networkx = nx.from_numpy_array(self.csr_adj.toarray())
        if self.pos is None:
            self.pos = nx.spring_layout(G_networkx, seed=seed)
        
        n_edges = self.csr_adj.nnz
        same_weight = True
        for i in range(n_edges):
            if self.csr_adj.data[i] != 0 and self.csr_adj.data[i] != 1:
                same_weight = False
                break
        if same_weight:
            print("all edge same value")
            weight = [width for i in range(n_edges)]
            edge_color = 'grey'
            nx.draw(G_networkx, node_size=node_size, width=weight, edge_color=edge_color, arrows = False, pos=self.pos,node_color = '#1f78b4')
        else:
            raise ValueError("Not implemented yet")
            

        fig.tight_layout()
        return fig
    def get_laplacian(self, laplacian_name="combinatorial"):
        if laplacian_name == "combinatorial":
            if self.laplacian_combinatorial is None:
                print("Computing combinatorial Laplacian: D-A")
                degree_matrix = np.array(self.csr_adj.sum(axis=1)).squeeze()
                self.laplacian_combinatorial = sp.diags_array(degree_matrix) - self.csr_adj
                #the type of the laplacian is csr_matrix
            return self.laplacian_combinatorial
        elif laplacian_name == "normalized":
            if self.normalized_laplacian is None:
                print("Computing normalized Laplacian D^(-0.5) * (D-A) * D^(-0.5)")
                degree_matrix = np.array(self.csr_adj.sum(axis=1)).squeeze()
                #degree_matrix = np.diag(np.sum(self.csr_adj, axis=1))
                degree_matrix_inv = degree_matrix**(-0.5)
                degree_matrix_inv = sp.diags_array(degree_matrix_inv)
                self.normalized_laplacian = degree_matrix_inv @ (sp.diags_array(degree_matrix) - self.csr_adj) @ degree_matrix_inv
            return self.normalized_laplacian
        elif laplacian_name == "normalized_self_loop":
            if self.normalized_self_loop_laplacian is None:
                print("Computing normalized Laplacian with self loop : (D+I)^(-0.5) (A+I) (D+I)^(-0.5)")
                degree_matrix = np.array(self.csr_adj.sum(axis=1)).squeeze()
                #degree_matrix = np.diag(np.sum(self.csr_adj, axis=1))
                degree_matrix_self_loop = degree_matrix + 1
                degree_matrix_self_loop_inv = degree_matrix_self_loop**(-0.5)
                degree_matrix_self_loop_inv = sp.diags_array(degree_matrix_self_loop_inv)
                self.normalized_self_loop_laplacian = degree_matrix_self_loop_inv @ (sp.diags_array(degree_matrix) - self.csr_adj) @ degree_matrix_self_loop_inv
            return self.normalized_self_loop_laplacian
        else:
            raise ValueError(f"Unknown Laplacian type: {laplacian_name}, consider using 'combinatorial', 'normalized' or 'normalized_self_loop' or implement the new type of Laplacian")


class CoarsenedGraph(Graph):
    """A ``Graph`` obtained by coarsening an original graph.

    Besides the ``Graph`` fields it stores:
        P: coarsening matrix (rows: coarse nodes, columns: original nodes).
        Q: lifting matrix; defaults to the pseudo-inverse of P, stored as CSR.
        R: optional extra matrix, saved when present.
        method_name / method_config: the coarsening method and its settings
            (``method_config`` may be a dict or an attribute container such as
            ``Config``).
        P_rao, P_loukas, ...: optional alternative coarsening matrices, saved
            alongside when not None.
    """

    # Optional alternative coarsening matrices: each is an attribute of the same
    # name and is saved to "<name>.npz" by ``save`` when not None (in this order).
    _EXTRA_P_NAMES = (
        "P_rao",
        "P_loukas",
        "P_mp",
        "P_optim_g",
        "P_optim_g_sp",
        "P_optim_supp",
        "P_optim_g_sp_supp",
        "P_optim_g_sp_threshold",
        "P_optim_g_sp_supp_threshold",
        "P_optim_g_sp_supp_threshold_low",
        "P_optim_g_sp_threshold_low",
        "P_optim_g_sp_supp_threshold_mid",
        "P_optim_g_sp_threshold_mid",
    )

    def __init__(self, csr_adj:sp.csr_matrix, P:sp.csr_matrix, Q:Optional[sp.csr_matrix]=None,
                 R:Optional[sp.csr_matrix]=None, method_name:Optional[str]=None,
                 method_config:Optional[Any]=None, features:Optional[sp.csr_matrix]=None,
                 train_mask:Optional[np.ndarray]=None, val_mask:Optional[np.ndarray]=None,
                 test_mask:Optional[np.ndarray]=None, labels:Optional[np.ndarray]=None,
                 pos:Optional[Any]=None,
                 P_rao:Optional[sp.csr_matrix]=None, P_loukas:Optional[sp.csr_matrix]=None,
                 P_mp:Optional[sp.csr_matrix]=None,
                 P_optim_g:Optional[sp.csr_matrix]=None,
                 P_optim_g_sp:Optional[sp.csr_matrix]=None,
                 P_optim_supp:Optional[sp.csr_matrix]=None,
                 P_optim_g_sp_supp:Optional[sp.csr_matrix]=None,
                 P_optim_g_sp_threshold:Optional[sp.csr_matrix]=None,
                 P_optim_g_sp_supp_threshold:Optional[sp.csr_matrix]=None,
                 P_optim_g_sp_supp_threshold_low:Optional[sp.csr_matrix]=None,
                 P_optim_g_sp_threshold_low:Optional[sp.csr_matrix]=None,
                 P_optim_g_sp_supp_threshold_mid:Optional[sp.csr_matrix]=None,
                 P_optim_g_sp_threshold_mid:Optional[sp.csr_matrix]=None
                 ) -> None:
        super().__init__(csr_adj, features, train_mask, val_mask, test_mask, labels, pos)
        self.P = P  # coarsening matrix
        self.R = R
        if Q is None:
            print("Q is not provided, we compute the pseudo inverse of the coarsening matrix to be the lifting matrix")
            P_dense = P.toarray() if sp.issparse(P) else np.asarray(P)
            # Store as CSR: an np.matrix (from todense) has no .toarray() and is
            # rejected by sp.save_npz, which broke get_lifting_torch() and save().
            self.Q = sp.csr_matrix(np.linalg.pinv(P_dense))
        else:
            self.Q = Q  # lifting matrix
        self.method_name = method_name
        self.method_config = method_config
        self.P_rao = P_rao
        self.P_loukas = P_loukas
        self.P_mp = P_mp
        self.P_optim_g = P_optim_g
        self.P_optim_g_sp = P_optim_g_sp
        self.P_optim_supp = P_optim_supp
        self.P_optim_g_sp_supp = P_optim_g_sp_supp
        self.P_optim_g_sp_threshold = P_optim_g_sp_threshold
        self.P_optim_g_sp_supp_threshold = P_optim_g_sp_supp_threshold
        self.P_optim_g_sp_supp_threshold_low = P_optim_g_sp_supp_threshold_low
        self.P_optim_g_sp_threshold_low = P_optim_g_sp_threshold_low
        self.P_optim_g_sp_supp_threshold_mid = P_optim_g_sp_supp_threshold_mid
        self.P_optim_g_sp_threshold_mid = P_optim_g_sp_threshold_mid

    def _method_config_dict(self):
        """Return ``method_config`` as a plain dict, or None when it is unset.

        Accepts a dict or an attribute container (its ``vars``). Calling
        ``.items()`` directly on a ``Config`` does not work: its ``__getattr__``
        returns None for unknown names.
        """
        config = self.method_config
        if config is None:
            return None
        if isinstance(config, dict):
            return config
        return vars(config)

    def show_info(self):
        """Print the ``Graph`` statistics plus coarsening-specific information."""
        super().show_info()
        print("Number of nodes before coarsening:", self.P.shape[1])
        if self.method_name is not None:
            print("Method name:", self.method_name)
            print("")
        config = self._method_config_dict()
        if config is not None:
            print("Method configuration:")
            for key, value in config.items():
                print(key, ":", value)

    def save(self,name_exp:str, path:Optional[str]="data/saved_coarsened_graph/"):
        """Save all matrices/arrays of the coarsened graph to a directory.

        The directory is ``path + method_name + "/" + name_exp + "/"`` (the
        method part is omitted when ``method_name`` is None); ``path`` is
        expected to end with a separator. Sparse matrices go to ``.npz``,
        numpy arrays to ``.npy``, and the method configuration to
        ``method_config.json``. Optional fields that are None are skipped.
        """
        if self.method_name is not None:
            path = path + self.method_name + "/"
        path = path + name_exp + "/"

        print("Saving the coarsened graph in", path)
        os.makedirs(path, exist_ok=True)
        sp.save_npz(path + "csr_adj.npz", self.csr_adj)
        sp.save_npz(path + "P.npz", self.P)
        # save_npz only accepts sparse input; convert a dense Q supplied by the caller.
        Q = self.Q if sp.issparse(self.Q) else sp.csr_matrix(self.Q)
        sp.save_npz(path + "Q.npz", Q)
        if self.R is not None:
            sp.save_npz(path + "R.npz", self.R)
        if self.features is not None:
            sp.save_npz(path + "features.npz", self.features)
        for array_name in ("train_mask", "val_mask", "test_mask", "labels"):
            array = getattr(self, array_name)
            if array is not None:
                np.save(path + array_name + ".npy", array)
        config = self._method_config_dict()
        if config is not None:
            with open(path + "method_config.json", 'w', encoding="utf-8") as f:
                json.dump(config, f, indent=4)
        for matrix_name in self._EXTRA_P_NAMES:
            matrix = getattr(self, matrix_name)
            if matrix is not None:
                sp.save_npz(path + matrix_name + ".npz", matrix)

    def get_lifting_torch(self):
        """Return the lifting matrix Q as a dense float32 torch tensor."""
        # TODO: a sparse tensor would save memory, but callers currently receive a dense one
        # (the previous `Q_torch.to_sparse()` call discarded its result and had no effect).
        Q_dense = self.Q.toarray() if sp.issparse(self.Q) else np.asarray(self.Q)
        return torch.tensor(Q_dense, dtype=torch.float32)

    def recompute_features(self, original_features:sp.csr_matrix,
                           P:Optional[sp.csr_matrix]=None) -> None:
        """
        Recompute the coarsened features as ``P @ original_features``.

        Args:
            original_features: feature matrix of the original graph.
            P: coarsening matrix to use; defaults to ``self.P``.
        """
        if P is None:
            print("No P provided, using the coarsening matrix of the coarsened graph")
            P = self.P
        self.features = P @ original_features
    