import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, SAGEConv
from typing import List
import torch.nn as nn
import math
from einops import rearrange
import random
import ot
import matplotlib as mpl
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import scipy.sparse as sp
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.manifold import TSNE
import os


# Compute device shared by the whole module. The experiments target the third GPU
# ("cuda:2"), but requesting that ordinal on a machine with fewer GPUs makes every
# `.to(device)` fail, so fall back to the default GPU and then to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:2' if torch.cuda.device_count() > 2 else 'cuda')
else:
    device = torch.device('cpu')


import numpy as np

def personalized_pagerank(adj_matrix, alpha=0.85, tol=1e-6, max_iter=100, personalize=None):
    """
    Compute Personalized PageRank scores by power iteration.

    Parameters:
    - adj_matrix (array-like or scipy.sparse matrix): (n, n) adjacency matrix; row i
      holds the outgoing edge weights of node i. Sparse input is densified.
    - alpha (float): Damping factor (default is 0.85).
    - tol (float): L1 convergence tolerance between successive iterates (default is 1e-6).
    - max_iter (int): Maximum number of iterations (default is 100).
    - personalize (array-like): Personalization vector of length n (default is uniform).
      It is normalized to sum to 1.

    Returns:
    - ppr (numpy.ndarray): (n,) PPR scores. Nodes whose row sums to 0 (dangling nodes)
      pass no mass on, so in that case the scores need not sum to 1.
    """
    if sp.issparse(adj_matrix):
        adj_matrix = adj_matrix.toarray()
    adj_matrix = np.asarray(adj_matrix, dtype=float)
    n = adj_matrix.shape[0]

    # Row-normalize adjacency matrix to create the transition matrix.
    # `out=` is required together with `where=`: without it, entries of zero-sum rows are
    # left uninitialised (arbitrary memory) instead of 0.
    row_sums = adj_matrix.sum(axis=1)
    transition_matrix = np.divide(
        adj_matrix,
        row_sums[:, None],
        out=np.zeros_like(adj_matrix),
        where=row_sums[:, None] != 0,
    )

    # Default uniform personalization if not provided
    if personalize is None:
        personalize = np.ones(n) / n
    personalize = np.asarray(personalize, dtype=float)
    personalize = personalize / np.sum(personalize)

    rank = np.ones(n) / n  # start from the uniform distribution
    for _ in range(max_iter):
        new_rank = (1 - alpha) * personalize + alpha * transition_matrix.T @ rank
        converged = np.linalg.norm(new_rank - rank, 1) < tol
        rank = new_rank  # keep the newest iterate, also on the converging step
        if converged:
            break

    return rank


'''Classes'''
# sin activation
class Sine(nn.Module):
    """Elementwise ``sin(w0 * x)`` activation used by SIREN layers."""

    def __init__(self, w0 = 1.):
        super().__init__()
        self.w0 = w0  # frequency scale applied before the sine

    def forward(self, x):
        scaled = self.w0 * x
        return torch.sin(scaled)

# Siren layer:
class Siren(nn.Module):
    """One SIREN layer: ``activation(x @ weight.T + bias)`` with SIREN-style init.

    Weight (and bias, when enabled) are drawn uniformly in ``[-bound, bound]``:
    ``sqrt(1 / dim_in)`` for ``activation='relu'``, ``1 / dim_in`` for the first
    layer, and ``sqrt(c / dim_in) / w0`` otherwise. Supported activation names:
    ``'sine'``, ``'relu'``, ``'id'`` and ``'sigmoid'``; anything else raises ValueError.
    """

    def __init__(self, dim_in, dim_out, w0 = 1., c = 6., is_first = False, use_bias = True, activation = 'sine'):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first
        self.act = activation

        # Fill plain tensors in place first, then register them as parameters.
        weight_init = torch.zeros(dim_out, dim_in)
        bias_init = torch.zeros(dim_out) if use_bias else None
        self.init_(weight_init, bias_init, c = c, w0 = w0)
        self.weight = nn.Parameter(weight_init)
        if use_bias:
            self.bias = nn.Parameter(bias_init)
        else:
            self.bias = None

        builders = (
            ('sine', lambda: Sine(w0)),
            ('relu', lambda: nn.ReLU(inplace=False)),
            ('id', nn.Identity),
            ('sigmoid', nn.Sigmoid),
        )
        for label, build in builders:
            if activation == label:
                self.activation = build()
                break
        else:
            raise ValueError('No mlp activation specified')

    def init_(self, weight, bias, c, w0):
        """Fill ``weight`` (and ``bias`` if given) in place with the SIREN uniform init."""
        fan_in = self.dim_in
        if self.act == 'relu':
            bound = math.sqrt(1 / fan_in)
        elif self.is_first:
            bound = 1 / fan_in
        else:
            bound = math.sqrt(c / fan_in) / w0

        weight.uniform_(-bound, bound)
        if bias is not None:
            bias.uniform_(-bound, bound)

    def forward(self, x):
        return self.activation(F.linear(x, self.weight, self.bias))

# Siren network:    
class SirenNet(nn.Module):
    """SIREN MLP: ``len(dim_hidden)`` Siren layers followed by a final Siren projection.

    Args:
        dim_in: input size (e.g. 2 for (x, y) coordinates).
        dim_hidden: sequence of hidden widths; its length determines the depth.
        dim_out: output size.
        num_layers: accepted for backward compatibility only. The depth and the input
            width of the last layer are taken from ``dim_hidden``, so a mismatching
            value can no longer produce a wrongly sized last layer.
        w0: frequency factor for the hidden layers after the first, and for the last layer.
        w0_initial: frequency factor for the first layer.
        use_bias: whether every layer has a bias.
        activation: activation name for the hidden layers (see ``Siren``).
        final_activation: activation name for the last layer; ``None`` means identity.
    """

    def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0 = 1., w0_initial = 30., use_bias = True, activation = 'sine', final_activation = 'sigmoid'):
        super().__init__()

        self.dim_hidden = dim_hidden
        self.num_layers = len(dim_hidden)
        self.layers = nn.ModuleList([])
        for ind in range(self.num_layers):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden[ind-1]
            self.layers.append(Siren(
                dim_in = layer_dim_in,
                dim_out = dim_hidden[ind],
                w0 = layer_w0,
                use_bias = use_bias,
                is_first = is_first,
                activation = activation
            ))

        final_activation = 'id' if not exists(final_activation) else final_activation
        # The last hidden width always comes from dim_hidden (previously indexed with the
        # `num_layers` argument, which could disagree with len(dim_hidden)).
        self.last_layer = Siren(dim_in = dim_hidden[-1], dim_out = dim_out, w0 = w0, use_bias = use_bias, activation = final_activation)

    def forward(self, x, mods = None):
        """Run the network.

        ``mods`` optionally modulates each hidden layer's output multiplicatively: a tuple
        is used as-is (one entry per layer), anything else is repeated for every layer.
        Each non-None entry is a 1-D tensor (pattern ``'d -> () d'``) broadcast over rows.
        """
        mods = cast_tuple(mods, self.num_layers)

        for layer, mod in zip(self.layers, mods):
            x = layer(x)

            if exists(mod):
                x = x*rearrange(mod, 'd -> () d')

        return self.last_layer(x)

# GNN network using GCNConv:
class GCN(torch.nn.Module):
    """Stack of GCNConv + ReLU layers followed by a linear head squashed into [0, 1].

    ``forward(x, edge_index)`` returns ``(x_pre_fc, x_post_fc)``:
    ``x_pre_fc`` holds the last hidden node embeddings.
    ``x_post_fc`` is the head output after a sigmoid and a min-max rescaling computed
    over the whole tensor (all nodes and all output dimensions together).
    """

    # Lower bound on the min-max range: avoids 0/0 -> NaN when every output is equal
    # (e.g. a single-node graph); such outputs become all zeros instead.
    _NORM_EPS = 1e-12

    def __init__(self, in_channels, hidden_channels, num_dimensions):
        super(GCN, self).__init__()
        self.layers = torch.nn.ModuleList()
        self.layers.append(GCNConv(in_channels, hidden_channels[0]))
        for i in range(1, len(hidden_channels)):
            self.layers.append(GCNConv(hidden_channels[i-1], hidden_channels[i]))

        # Final fully connected layer
        self.fc1 = torch.nn.Linear(hidden_channels[-1], num_dimensions)

    def forward(self, x, edge_index):
        for conv in self.layers:
            x = F.relu(conv(x, edge_index))
        x_pre_fc = x
        x_post_fc = torch.sigmoid(self.fc1(x_pre_fc))
        lo, hi = x_post_fc.min(), x_post_fc.max()
        x_post_fc = (x_post_fc - lo) / (hi - lo).clamp_min(self._NORM_EPS)

        return x_pre_fc, x_post_fc
    
# ISGL network without pooling:
class ISGL(nn.Module):
    """Couples a node-coordinate GNN (``model1``) with a coordinate INR (``model2``).

    ``model1(g.x, g.edge_index)`` must return ``(embeddings, coords)`` where ``coords``
    has one scalar per node (its last dimension is squeezed). The INR is then evaluated on
    every ordered node pair ``(coords[i], coords[j])``, enumerated row-major in ``i``.
    Returns ``(inr_output, coords)``.
    """

    def __init__(self, model1, model2):
        super(ISGL, self).__init__()
        self.model1 = model1
        self.model2 = model2

    def forward(self, g):
        _, latent = self.model1(g.x, g.edge_index)
        num_nodes = latent.size(0)
        eta = latent.squeeze(-1)

        # Pair i*num_nodes + j holds (eta_i, eta_j): rows repeat each value, cols tile the vector.
        rows = eta.repeat_interleave(num_nodes)
        cols = eta.repeat(num_nodes)
        pairs = torch.stack([rows, cols], dim=1).to(device)

        return self.model2(pairs), latent

# GNN network:
class GNN_class(torch.nn.Module):
    """Two-layer GCN node classifier returning log-probabilities over ``out_channels``."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GNN_class, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        hidden = F.relu(self.conv1(x, edge_index))
        logits = self.conv2(hidden, edge_index)
        # log-softmax over the class dimension, as expected by NLL-style losses
        return F.log_softmax(logits, dim=1)
    
# MLP network:
class MLP(nn.Module):
    """Fully connected network built from ``layer_sizes``.

    Every hidden Linear layer is followed by a Softmax over dim 1 (kept from the original
    design; the original comment notes another activation could be used instead).
    When ``is_classification`` is True (the default, as before), a final Softmax over dim 1
    is appended to the last Linear layer.

    Args:
        layer_sizes: sequence ``[in, hidden..., out]`` with at least two entries.
        is_classification: whether to append the final Softmax.

    Raises:
        ValueError: if ``layer_sizes`` has fewer than two entries.
    """

    def __init__(self, layer_sizes, is_classification=True):
        super(MLP, self).__init__()
        if len(layer_sizes) < 2:
            raise ValueError('layer_sizes needs at least an input and an output size')

        layers = []
        for size_in, size_out in zip(layer_sizes[:-2], layer_sizes[1:-1]):  # hidden layers
            layers.append(nn.Linear(size_in, size_out))
            # Explicit dim: the implicit-dim Softmax is deprecated and picks dim 0 for 3-D input.
            layers.append(nn.Softmax(dim=1))

        layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1]))  # final layer

        self.is_classification = is_classification
        if self.is_classification:
            layers.append(nn.Softmax(dim=1))

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
    
'''Functions'''
def nx2torch2(adjacency_matrix, features=None):
    """Wrap an adjacency matrix as a torch_geometric ``Data`` object moved to ``device``.

    - ``edge_index``: (2, E) indices of every nonzero entry of the matrix.
    - ``y``: the whole matrix flattened row-major (length N*N), used as edge targets.
    - ``x``: ``features`` as float, or random N(0, 1) node features of shape (N, 1)
      when ``features`` is None.
    Sparse scipy matrices are densified first.
    """
    dense = adjacency_matrix.toarray() if sp.issparse(adjacency_matrix) else adjacency_matrix
    adj = torch.tensor(dense, dtype=torch.float)

    edge_index = adj.nonzero(as_tuple=False).t().contiguous()
    targets = adj.reshape(-1)

    if features is not None:
        node_x = torch.tensor(features, dtype=torch.float)
    else:
        node_x = torch.randn(adj.size(0), 1)

    return Data(x=node_x, edge_index=edge_index, y=targets).to(device)


def _histogram_block_size(hsize, N):
    """Map the ``hsize`` option to an average-pooling kernel size for an N-node graph.

    1, 2, 3 -> that fixed size; 4 -> int(log N); 5 -> int(sqrt N). The result is clamped
    to [1, N] because a kernel larger than the graph makes ``avg_pool2d`` fail.
    """
    if hsize == 1:
        h = 1
    elif hsize == 2:
        h = 2
    elif hsize == 3:
        h = 3
    elif hsize == 4:
        h = int(np.log(N))
    elif hsize == 5:
        h = int(np.sqrt(N))
    else:
        raise ValueError('hsize must be one of 1, 2, 3, 4, 5; got {!r}'.format(hsize))
    return max(1, min(h, N))


def _upper_triangle_samples(H):
    """Return ``(X, y)`` for the upper triangle (incl. diagonal) of the k x k grid ``H``.

    ``X`` is (M, 2) of cell-centre coordinates ``(x_j, y_i)`` in (0, 1) and ``y`` is (M, 1)
    of the matching values ``H[i, j]``, with cells enumerated row-major.
    """
    k = H.shape[0]
    centres = (np.arange(k) + 0.5) / k
    xx, yy = np.meshgrid(centres, centres)
    mask = np.triu(np.ones((k, k), dtype=bool))
    X = np.column_stack((xx[mask], yy[mask]))
    y = H[mask].reshape(-1, 1)
    return X, y


def graph2XY(graphs, num_nodes, model_ISGL, latent_val=None, hsize=4, sortDeg=0):
    """Turn graphs into (coordinate, value, weight) training samples for an INR.

    For each graph:
    1. Order its nodes, either by column sum (``sortDeg`` truthy) or by the latent
       coordinate predicted by ``model_ISGL.model1``.
    2. Average-pool the sorted adjacency into a histogram with the kernel picked by
       ``hsize``; trailing rows/columns that do not fill a kernel are dropped by
       ``avg_pool2d``.
    3. Keep the upper triangle of that histogram.

    :param graphs: list of (N_i, N_i) adjacency matrices (dense or scipy sparse)
    :param num_nodes: normaliser for the per-graph sample weight ``N_i / num_nodes``
    :param model_ISGL: model with a ``model1(x, edge_index)`` GNN returning
        ``(embeddings, coords)``; only used when ``sortDeg`` is falsy
    :param latent_val: optional per-graph latent vectors appended as extra X columns
    :param hsize: pooling-size rule, see ``_histogram_block_size``
    :param sortDeg: sort nodes by column sum instead of by the GNN coordinate
    :return: ``(X_all, y_all, w_all, perm)``. The first three are float tensors on
        ``device``. ``perm`` is the node ordering of the LAST graph only.
    """
    X_all = []
    y_all = []
    w_all = []
    perm = None

    for graph_idx, graph_test in enumerate(graphs):
        if sp.issparse(graph_test):
            graph_test = graph_test.toarray()  # Convert sparse matrix to dense
        weight = graph_test.shape[0] / num_nodes

        if sortDeg:
            # sort by degree (column sums)
            perm = np.argsort([np.sum(graph_test[:, j]) for j in range(graph_test.shape[0])])
        else:
            data_i = nx2torch2(graph_test)
            # Only the ordering is needed, so don't build an autograd graph.
            with torch.no_grad():
                _, output_gnn_post = model_ISGL.model1(data_i.x, data_i.edge_index)
            perm = torch.argsort(output_gnn_post.squeeze(-1)).cpu().numpy()

        graph_test_sort = graph_test[perm, :][:, perm]

        # Average pooling to get the histogram
        N = graph_test_sort.shape[0]
        h = _histogram_block_size(hsize, N)
        A = torch.tensor(graph_test_sort, dtype=torch.float, device=device).view(1, 1, N, N)
        # Index [0, 0] instead of squeeze(): squeeze() collapses a 1x1 histogram to a scalar.
        H = F.avg_pool2d(A, kernel_size=(h, h))[0, 0].cpu().numpy()

        X, y = _upper_triangle_samples(H)
        w = np.tile(weight, (X.shape[0], 1))
        if latent_val is not None:
            z_this = latent_val[graph_idx]
            zz = np.tile(z_this, (X.shape[0], 1))
            X = np.column_stack((X, zz))

        X_all.append(X)
        y_all.append(y)
        w_all.append(w)

    # Convert lists to tensors
    X_all = torch.tensor(np.concatenate(X_all, axis=0), dtype=torch.float, device=device)
    y_all = torch.tensor(np.concatenate(y_all, axis=0), dtype=torch.float, device=device)
    w_all = torch.tensor(np.concatenate(w_all, axis=0), dtype=torch.float, device=device)

    return X_all, y_all, w_all, perm

def exists(val):
    """Return True when ``val`` is anything other than ``None`` (falsy values count)."""
    return not (val is None)

def cast_tuple(val, repeat = 1):
    """Return ``val`` unchanged if it is a tuple, else a tuple of ``repeat`` copies of it."""
    if isinstance(val, tuple):
        return val
    return (val,) * repeat

def synthesize_graphon(r: int = 1000, type_idx: int = 0, alpha: float = 1) -> np.ndarray:
    """
    Synthesize a discretized benchmark graphon.

    :param r: the resolution of the discretized graphon (grid points per axis)
    :param type_idx: the type of graphon: 0-12, 14, 19 or 101
    :param alpha: bandwidth used by types 19 and 101 (ignored otherwise)
    :return:
        w: (r, r) float array with a zero diagonal. Entries are in [0, 1] for the
        bundled types, except type 101, whose three summed bumps may exceed 1.
    :raises ValueError: for an unknown ``type_idx``
    """
    u = ((np.arange(0, r) + 1) / r).reshape(-1, 1)  # (r, 1), grid points 1/r .. 1
    v = ((np.arange(0, r) + 1) / r).reshape(1, -1)  # (1, r)
    # Several types are defined on the reversed grid (1 .. 1/r).
    u_rev = u[::-1, :]
    v_rev = v[:, ::-1]

    if type_idx == 0:
        w = u_rev @ v_rev
    elif type_idx == 1:
        w = np.exp(-(u ** 0.7 + v ** 0.7))
    elif type_idx == 2:
        w = 0.25 * (u_rev ** 2 + v_rev ** 2 + u_rev ** 0.5 + v_rev ** 0.5)
    elif type_idx == 3:
        w = 0.5 * (u_rev + v_rev)
    elif type_idx == 4:
        w = 1 / (1 + np.exp(-2 * (u_rev ** 2 + v_rev ** 2)))
    elif type_idx == 5:
        w = 1 / (1 + np.exp(-(np.maximum(u_rev, v_rev) ** 2 + np.minimum(u_rev, v_rev) ** 4)))
    elif type_idx == 6:
        w = np.exp(-np.maximum(u, v) ** 0.75)
    elif type_idx == 7:
        w = np.exp(-0.5 * (np.minimum(u, v) + u ** 0.5 + v ** 0.5))
    elif type_idx == 8:
        w = np.log(1 + 0.5 * np.maximum(u_rev, v_rev))
    elif type_idx == 9:
        w = np.abs(u - v)
    elif type_idx == 10:
        w = 1 - np.abs(u - v)
    elif type_idx in (11, 12):
        # Two diagonal blocks of size r2 and r - r2. Built by slicing rather than
        # np.kron(eye(2), ones((r2, r2))), which returned (r-1, r-1) for odd r.
        r2 = int(r / 2)
        blocks = np.zeros((r, r))
        blocks[:r2, :r2] = 1.
        blocks[r2:, r2:] = 1.
        # 11: assortative (within-block 0.8); 12: disassortative (between-block 0.8)
        w = 0.8 * blocks if type_idx == 11 else 0.8 * (1 - blocks)
    elif type_idx == 14:
        w = 0.2 * np.ones((r, r))
    elif type_idx == 19:
        w = np.exp(-(u ** 0.7 + v ** 0.7) / alpha)
    elif type_idx == 101:
        theta = np.pi * 3 / 4
        w = (0.9 * np.exp((-(v) ** 2 - (u - 1) ** 2) / alpha ** 2)
             + 0.9 * np.exp((-(u) ** 2 - (v - 1) ** 2) / alpha ** 2)
             + 0.9 * np.exp(-((np.sin(theta) * u + np.cos(theta) * v) / alpha) ** 2))
    else:
        raise ValueError('Unknown graphon type')

    np.fill_diagonal(w, 0.)
    return w

def synthesize_SBM_graphon(r: int = 1000, p1: float = 0.8, p2: float = 0.8, q: float = 0.1, s: float = 0.5) -> np.ndarray:
    """
    Synthesize a two-block stochastic-block-model graphon.

    :param r: the resolution of the discretized graphon
    :param p1: within-block probability of the first block (first int(r * s) rows/cols)
    :param p2: within-block probability of the second block (remaining rows/cols)
    :param q: between-block probability
    :param s: fraction of the grid assigned to the first block
    :return:
        w: (r, r) float array; unlike synthesize_graphon, the diagonal is not zeroed
    """
    split = int(r * s)
    # Start from the between-block value and overwrite the two diagonal blocks.
    w = np.full((r, r), q, dtype=float)
    w[:split, :split] = p1
    w[split:, split:] = p2
    return w

def simulate_graphs(w: np.ndarray, seed_gsize: int=123, seed_edge:int=123, num_graphs: int = 10,
                    num_nodes: int = 200, graph_size: str = 'fixed', offset:int=0) -> List[np.ndarray]:
    """
    Simulate graphs by sampling node positions on a discretized graphon.

    :param w: a (r, r) discretized graphon
    :param seed_gsize: seed for drawing the graph sizes when graph_size == 'vary'
    :param seed_edge: seed for node positions and edge noise
    :param num_graphs: the number of simulated graphs
    :param num_nodes: the number of nodes per graph when graph_size is not 'vary'
    :param graph_size: 'vary' draws each size uniformly from [50 + offset, 350 + offset);
        any other value gives every graph num_nodes nodes
    :param offset: shift applied to the 'vary' size range
    :return:
        graphs: a list of float 0/1 adjacency matrices with zero diagonal. Each entry
        (i, j) is thresholded against its own noise draw, so the matrices are not
        guaranteed to be symmetric.

    Note: reseeds NumPy's global RNG (np.random.seed).
    """
    r = w.shape[0]

    if graph_size == 'vary':
        # Seed the size draw as well, so 'vary' mode is reproducible
        # (seed_gsize was previously accepted but never used).
        np.random.seed(seed_gsize)
        numbers = np.random.randint(50 + offset, 350 + offset, num_graphs)
    else:  # fixed size
        numbers = [num_nodes] * num_graphs

    np.random.seed(seed_edge)  # reproducible node positions and edge noise
    graphs = []
    for size in numbers:
        node_locs = (r * np.random.rand(size)).astype('int')  # latent grid cell per node
        graph = w[np.ix_(node_locs, node_locs)]  # fancy indexing -> fresh copy
        noise = np.random.rand(graph.shape[0], graph.shape[1])
        graph -= noise
        np.fill_diagonal(graph, 0)
        graphs.append((graph > 0).astype('float'))

    return graphs

def gw_distance(graphon: np.ndarray, estimation: np.ndarray) -> float:
    """Gromov-Wasserstein distance between two discretized graphons.

    Both inputs get uniform node measures and are compared with POT's
    ``gromov_wasserstein2`` using the square loss. The returned value is treated as a
    squared distance, so its square root is returned.
    """
    p = np.ones((graphon.shape[0],)) / graphon.shape[0]
    q = np.ones((estimation.shape[0],)) / estimation.shape[0]
    loss_fun = 'square_loss'
    dw2 = ot.gromov.gromov_wasserstein2(graphon, estimation, p, q, loss_fun, log=False, armijo=False)
    # Round-off can leave a tiny negative value; clamp so the square root is not NaN.
    return np.sqrt(max(dw2, 0.0))

def coords_prediction(inr_dim_hidden, gnn_dim_hidden, n_epochs, epoch_show, w0, graphs, lr):
    """Jointly train a GCN (latent node coordinates) and a SIREN INR on ``graphs``.

    For every graph, the ISGL model predicts an entry for every ordered node pair. The MSE
    against the flattened adjacency (``Data.y``) is minimised with Adam, one optimiser step
    per graph. Graphs are shuffled each epoch. If a loss is NaN, the rest of that epoch
    is skipped.

    Returns ``(model_ISGL, loss)``, where ``loss`` is the summed loss of the last epoch.
    """
    # Build order matters for RNG reproducibility: INR, then GNN, then graph features.
    inr = SirenNet(dim_in = 2,  # input [x,y] coordinate
                   dim_hidden = inr_dim_hidden,
                   dim_out = 1,  # output graphon (edge) probability
                   num_layers = len(inr_dim_hidden),
                   final_activation = 'sigmoid',
                   w0_initial = w0).to(device)
    gnn = GCN(1, gnn_dim_hidden, 1).to(device)
    model_ISGL = ISGL(gnn, inr).to(device)

    dataset = [nx2torch2(adj) for adj in graphs]
    optimizer = torch.optim.Adam(model_ISGL.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()
    model_ISGL.train()

    for epoch in range(1, n_epochs + 1):
        random.shuffle(dataset)
        epoch_loss = 0
        for sample in dataset:
            predicted_edges, _ = model_ISGL(sample)
            loss = criterion(predicted_edges.squeeze(-1), sample.y)
            if torch.isnan(loss):
                break
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_loss = epoch_loss + loss.item()
        if epoch == 1 or epoch % epoch_show == 0:
            print("epoch: ", epoch, " loss: ", epoch_loss)

    return model_ISGL, epoch_loss

def train_graphon(inr_dim_hidden, w0, X_all, y_all, w_all, n_epochs, epoch_show, lr, batch_size, wLoss=0, isparametric=0):
    """Fit a SIREN INR mapping coordinates (plus ``isparametric`` latent columns) to values.

    Trains with Adam on shuffled mini-batches of ``(X_all, y_all, w_all)`` with an
    element-wise MSE. When ``wLoss`` is truthy each element's loss is scaled by its
    weight before averaging. Prints the summed epoch loss on epoch 1 and every
    ``epoch_show`` epochs. Returns the trained model (on ``device``).
    """
    inr_model = SirenNet(dim_in = 2 + isparametric,  # input [x,y] coordinate (+ latent)
                         dim_hidden = inr_dim_hidden,
                         dim_out = 1,  # output graphon (edge) probability
                         num_layers = len(inr_dim_hidden),
                         final_activation = 'sigmoid',
                         w0_initial = w0).to(device)
    dataset = torch.utils.data.TensorDataset(X_all, y_all, w_all)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(inr_model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss(reduction='none')

    for epoch in range(1, n_epochs + 1):
        running = 0
        for X, y, w in train_loader:
            X, y, w = X.to(device), y.to(device), w.to(device)
            optimizer.zero_grad()
            per_element = criterion(inr_model(X), y)
            batch_loss = (per_element * w).mean() if wLoss else per_element.mean()
            batch_loss.backward()
            optimizer.step()
            running = running + batch_loss.item()
        if epoch == 1 or epoch % epoch_show == 0:
            print('Epoch: {}, Loss: {:.5f}'.format(epoch, running))

    return inr_model

def get_graphon(Res, model, gpu=False):
    """Evaluate a coordinate model on a Res x Res grid and return a symmetric graphon.

    The model is queried only on the upper triangle (incl. diagonal) of the grid of cell
    centres ``(i + 0.5) / Res``. The result is mirrored to the lower triangle and the
    diagonal is set to 0. ``model`` is switched to eval mode.

    :param Res: grid resolution
    :param model: torch module mapping (M, 2) coordinates to (M, 1) values
    :param gpu: only consulted when ``model`` has no parameters (True -> ``device``,
        False -> CPU). Otherwise inputs are placed on the device of the model's
        parameters, so a GPU-trained model also works with the default ``gpu=False``.
    :return: (Res, Res) numpy array
    """
    k = Res
    centres = (np.arange(k) + 0.5) / k
    xx, yy = np.meshgrid(centres, centres)
    upper_triangle_mask = np.triu(np.ones((k, k), dtype=bool))
    X_upper_triangle = np.column_stack((xx[upper_triangle_mask], yy[upper_triangle_mask]))

    # Feed inputs to wherever the model lives; mismatched devices raise in forward().
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target = first_param.device
    else:
        target = device if gpu else torch.device('cpu')
    X_torch = torch.tensor(X_upper_triangle, dtype=torch.float, device=target)

    model.eval()
    with torch.no_grad():
        graphon_upper = model(X_torch)
    graphon = np.zeros((k, k))
    # Boolean-mask order and triu_indices order are both row-major, so they line up.
    graphon[np.triu_indices(k, 0)] = graphon_upper.cpu().numpy().reshape(-1)
    graphon = graphon + graphon.T
    np.fill_diagonal(graphon, 0)
    return graphon

def plot_smaples(true_graph, predictions_graphon, name, model_ISGL, drop):
    """Save a 4-panel figure: graph, latent vs. degree, sorted graph, predicted graphon.

    The figure is written to ``Plots/<name>/AllatOnce<drop>.jpg``; the directory is
    created if missing. Text is rendered with LaTeX (``usetex=True``), which needs a
    LaTeX installation. ``model_ISGL`` is left in eval mode.

    :param true_graph: (N, N) adjacency matrix, indexed as ``[perm, :][:, perm]``
    :param predictions_graphon: 2-D array shown with ``imshow``
    :param name: sub-directory of ``Plots`` to write into
    :param model_ISGL: ISGL model; only its ``model1`` GNN is evaluated here
    :param drop: value appended to the output file name
    """
    model_ISGL.eval()
    data_i = nx2torch2(true_graph)
    # Only the latent node coordinates are needed: skip the O(N^2) INR pass and autograd.
    with torch.no_grad():
        _, coords = model_ISGL.model1(data_i.x, data_i.edge_index)
    perm = torch.argsort(coords.squeeze(-1)).cpu().numpy()
    true_graph_sorted = true_graph[perm, :][:, perm]
    coords = coords.cpu().numpy()

    # plot the true graph and sorted graph based on the coordinates
    plt.figure(figsize=(15, 5))
    mpl.rcParams['font.family'] = 'serif'
    plt.rc('text', usetex=True)
    plt.rc('text.latex', preamble=r'\usepackage{amsmath}')

    plt.subplot(1, 4, 1)
    plt.spy(true_graph, markersize=1)
    plt.title(r'True graph', fontsize=18)

    plt.subplot(1, 4, 2)
    plt.scatter([np.sum(true_graph[:, j]) for j in range(true_graph.shape[0])], coords)
    plt.xlabel(r'Node degree', fontsize=16)
    plt.ylabel(r'$\hat{\boldsymbol{\eta}}$', fontsize=16)
    plt.title(r'Estimated latent variables', fontsize=18)

    plt.subplot(1, 4, 3)
    plt.spy(true_graph_sorted, markersize=1)
    plt.title(r'True graph sorted', fontsize=18)

    plt.subplot(1, 4, 4)
    plt.imshow(predictions_graphon, cmap='viridis')
    plt.axis('off')
    plt.title(r'Predicted graphon $f_{\theta^*}(x,y)$', fontsize=18)

    plt.tight_layout()
    # savefig does not create missing directories (plot_tsne_2d handles this the same way).
    os.makedirs('Plots/' + name, exist_ok=True)
    plt.savefig('Plots/'+ name + "/AllatOnce" + str(drop) + ".jpg", dpi=300)



def plot_edge_confusion_matrices(adj, aug_adj1, aug_adj2, name, drop):
    """Save side-by-side 0/1 edge confusion matrices of two augmented views vs. ``adj``.

    Inputs may be dense arrays or scipy sparse matrices of equal shape. Counts are
    integer-halved, presumably because each undirected edge appears twice in a symmetric
    matrix (TODO confirm for asymmetric inputs). Output goes to
    ``Plots/<name>/aug_edge_confusion<drop>.jpg``; the directory is created if missing.
    """
    def _flat(matrix):
        # augSIGL returns dense arrays unless drop_percent == 1.0, so accept both kinds.
        dense = matrix.toarray() if sp.issparse(matrix) else np.asarray(matrix)
        return dense.flatten()

    adj_flat = _flat(adj)
    aug_adj1_flat = _flat(aug_adj1)
    aug_adj2_flat = _flat(aug_adj2)

    # labels=[0, 1] keeps each matrix 2x2 even when one class is absent, matching the
    # two display_labels below.
    aug1_confusion = confusion_matrix(adj_flat, aug_adj1_flat, labels=[0, 1]) // 2
    aug2_confusion = confusion_matrix(adj_flat, aug_adj2_flat, labels=[0, 1]) // 2

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Plot the first confusion matrix
    disp1 = ConfusionMatrixDisplay(aug1_confusion, display_labels=["0", "1"])
    disp1.plot(ax=axes[0], cmap="viridis", values_format='d')
    axes[0].set_title("Augmentation 1")

    # Plot the second confusion matrix
    disp2 = ConfusionMatrixDisplay(aug2_confusion, display_labels=["0", "1"])
    disp2.plot(ax=axes[1], cmap="viridis", values_format='d')
    axes[1].set_title("Augmentation 2")

    plt.tight_layout()
    os.makedirs('Plots/' + name, exist_ok=True)
    plt.savefig('Plots/' + name + '/aug_edge_confusion' + str(drop) + '.jpg', dpi=300)



def modify_adjacency_matrix(adj_matrix, another_matrix, edge_ratio):
    """Overwrite a random fraction of edges and of non-edges with values from another matrix.

    ``int(count * edge_ratio)`` positions are drawn without replacement (via np.random)
    from the entries with ``adj_matrix == 1``, and independently from those with
    ``adj_matrix == 0``. Those entries are replaced by the corresponding entries of
    ``another_matrix``. Positions are sampled over the full matrix, so a symmetric input
    is not guaranteed to stay symmetric.

    Note: ``adj_matrix`` is modified in place and also returned.
    """
    # Collect both position sets before anything is overwritten.
    candidate_sets = (np.argwhere(adj_matrix == 1), np.argwhere(adj_matrix == 0))

    chosen_sets = []
    for positions in candidate_sets:
        how_many = int(len(positions) * edge_ratio)
        picks = np.random.choice(len(positions), how_many, replace=False)
        chosen_sets.append(positions[picks])

    for chosen in chosen_sets:
        rows, cols = chosen[:, 0], chosen[:, 1]
        adj_matrix[rows, cols] = another_matrix[rows, cols]

    return adj_matrix



def graphon_conv(W, X, h, n):
    """
    Implements the given operation:
    h *W X = sum_{k=0}^{K-1} h_k (T_W^(k) X)
    
    Parameters:
    W : Graphon
        Weight matrix of shape (N, N) representing W(u, v).
    X : features
        Input matrix of shape (N,1) representing X(u).
    h : list or numpy.ndarray
        Coefficients [h_0, h_1, ..., h_{K-1}].

    Returns:
    numpy.ndarray
        The transformed matrix of shape (N, 1).
    """
    K = len(h)
    T_X = [X]  # T_W^(0) X is identity, so it is X itself

    # Compute the successive transformations T_W^(k) X
    for k in range(1, K):
        T_k_X = W @ T_X[-1]  # Integral is replaced by matrix multiplication
        T_k_X = T_k_X / n  # Normalize by the number of nodes
        T_X.append(T_k_X)

    # Weighted sum of transformed versions
    result = sum(h[k] * T_X[k] for k in range(K))
    
    return result



def graphon_layer(W, X_prev, H, n, activation=np.tanh):
    """
    One graphon-neural-network layer:
    X_{l}^{f} = rho( sum_{g=1}^{F_{l-1}} graphon_conv(W, X_{l-1}^{g}, H[f, g, :], n) )

    Parameters:
    W : numpy.ndarray
        Discretized graphon of shape (N, N) representing W(u, v).
    X_prev : list of numpy.ndarray
        List of F_{l-1} arrays, each of shape (N, 1), representing X_{l-1}^{g}.
    H : numpy.ndarray
        3D array of shape (F_l, F_{l-1}, K); H[f, g, :] holds the filter taps h_{l}^{fg}.
    n : number
        Normaliser forwarded to graphon_conv (the number of nodes).
    activation : callable
        Element-wise non-linearity rho (default: tanh).

    Returns:
    list of numpy.ndarray
        List of F_l arrays, each of shape (N, 1), representing X_{l}^{f}.
    """
    F_l, F_l_minus_1, K = H.shape  # unpacking also enforces a 3-D H
    X_l = []

    for f in range(F_l):  # output features
        sum_term = np.zeros_like(X_prev[0])
        for g in range(F_l_minus_1):  # input features
            # Out-of-place add: `+=` raises a casting error when X_prev holds integer
            # arrays but the filtered term is float.
            sum_term = sum_term + graphon_conv(W, X_prev[g], H[f, g, :], n)
        X_l.append(activation(sum_term))

    return X_l


def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))`` using NumPy."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def relu(x):
    """Element-wise rectified linear unit ``max(0, x)`` using NumPy."""
    floor = 0
    return np.maximum(floor, x)




    



def plot_tsne_2d(X_G, G_Y, seed, ds, aug, eval, ACC):
    """Project ``X_G`` to 2-D with t-SNE and save a scatter plot coloured by ``G_Y``.

    With ``eval`` truthy the legend says "Class", the title reports ``ACC`` rounded to
    2 decimals, and the file is ``Plots/<ds>/<aug>_eval.jpg``. Otherwise the legend says
    "Label" and the file is ``Plots/<ds>/<aug>_PosVsNeg.jpg``. The directory is created
    if it does not exist.
    """
    embedded = TSNE(n_components=2, random_state=seed).fit_transform(X_G)

    plt.figure(figsize=(6, 6))

    # One viridis colour per distinct label
    classes = np.unique(G_Y)
    palette = plt.get_cmap('viridis', len(classes))(np.linspace(0, 1, len(classes)))

    legend_word = 'Class' if eval else 'Label'
    for shade, label in zip(palette, classes):
        members = G_Y == label
        plt.scatter(embedded[members, 0], embedded[members, 1],
                    color=shade, label=f'{legend_word} {label}', alpha=0.7)

    plt.legend(loc='best')
    if eval:
        plt.title("Eval data of " + aug + ' (Acc: ' + str(np.round(ACC,2)) + ')')
    else:
        plt.title("Training data of " + aug)
    plt.tight_layout()
    plt.grid()

    out_dir = "Plots/" + ds
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    suffix = "_eval.jpg" if eval else "_PosVsNeg.jpg"
    plt.savefig(out_dir + "/" + aug + suffix)




def _symmetric_uniform_noise(shape):
    """Draw U[0, 1) noise (np.random) and mirror its upper triangle so it is symmetric."""
    noise = np.triu(np.random.rand(*shape))
    return noise + noise.T - np.diag(np.diag(noise))


def augSIGL(adj, lr, drop_percent):
    """Create two graphon-resampled augmented views of ``adj``.

    Steps:
    1. Learn latent node coordinates with ``coords_prediction`` (fixed lr 0.01).
    2. Fit an INR graphon on the sorted histogram with ``train_graphon`` (using ``lr``).
    3. Resample two symmetric adjacency matrices from the fitted edge probabilities and
       map them back to the original node order.

    :param adj: (N, N) adjacency matrix (presumably dense or scipy sparse -- confirm)
    :param lr: learning rate for the graphon (INR) fit
    :param drop_percent: fraction of upper-triangular node pairs whose entries are
        replaced by resampled values. At exactly 1.0 both views are fully resampled and
        returned as CSR matrices; otherwise copies of ``adj`` are edited symmetrically.
    :return: ``(aug_adj1, aug_adj2, trained_inr, model_ISGL)``
    """
    N_adj = adj.shape[0]
    graphs_inr = [adj]
    inr_dim_hidden = [20, 20, 20]
    lr_inr = 0.01
    batch_size_inr = 1024
    gnn_dim_hidden = [8, 8]
    w0 = 10
    num_nodes_all = sum(graph_i.shape[0] for graph_i in graphs_inr)
    model_ISGL, _ = coords_prediction(inr_dim_hidden, gnn_dim_hidden, 100, 100, w0, graphs_inr, lr_inr)
    X_all, y_all, w_all, perm = graph2XY(graphs_inr, num_nodes_all, model_ISGL)
    trained_inr = train_graphon(inr_dim_hidden, w0, X_all, y_all, w_all, 100, 100, lr, batch_size_inr)

    # Edge probabilities in sorted node order. trained_inr lives on `device`, so evaluate
    # the grid there (the default gpu=False would feed CPU inputs to a GPU model).
    adj_prob = get_graphon(N_adj, trained_inr, gpu=True)

    # Symmetric noise -> symmetric Bernoulli samples for two independent views
    noise_1 = _symmetric_uniform_noise(adj_prob.shape)
    noise_2 = _symmetric_uniform_noise(adj_prob.shape)
    sample_adj1 = (noise_1 <= adj_prob).astype(np.int32)
    sample_adj2 = (noise_2 <= adj_prob).astype(np.int32)

    # Undo the latent-coordinate sort so samples align with adj's node order.
    inv_perm = np.argsort(perm)
    sample_adj1 = sample_adj1[inv_perm, :][:, inv_perm]
    sample_adj2 = sample_adj2[inv_perm, :][:, inv_perm]

    if drop_percent == 1.0:
        # Fully resampled views (this branch previously returned names it never defined).
        return sp.csr_matrix(sample_adj1), sp.csr_matrix(sample_adj2), trained_inr, model_ISGL

    aug_adj1_gr = adj.copy()
    aug_adj2_gr = adj.copy()
    # Replace a random subset of upper-triangular pairs, writing both (i, j) and (j, i).
    row_idx, col_idx = np.triu_indices(N_adj, k=1)
    num_possible_edges = row_idx.shape[0]
    num_selected_edges = int(num_possible_edges * drop_percent)
    selected_indices = np.random.choice(num_possible_edges, num_selected_edges, replace=False)
    selected_row_idx = row_idx[selected_indices]
    selected_col_idx = col_idx[selected_indices]
    aug_adj1_gr[selected_row_idx, selected_col_idx] = sample_adj1[selected_row_idx, selected_col_idx]
    aug_adj1_gr[selected_col_idx, selected_row_idx] = sample_adj1[selected_row_idx, selected_col_idx]
    aug_adj2_gr[selected_row_idx, selected_col_idx] = sample_adj2[selected_row_idx, selected_col_idx]
    aug_adj2_gr[selected_col_idx, selected_row_idx] = sample_adj2[selected_row_idx, selected_col_idx]

    return aug_adj1_gr, aug_adj2_gr, trained_inr, model_ISGL





def getR(c_l, y_l, isMGCL=True):
    """Per-anchor ratio of true negatives to false negatives within a batch.

    For anchor i, every other sample j is a candidate negative. With ``isMGCL`` only
    samples with ``c_j != c_i`` are considered. A candidate with the same label
    (``y_j == y_i``) counts as a false negative (FN), otherwise as a true negative (TN).
    ``R_i = TN_i / FN_i``. Anchors with ``FN_i == 0`` get the largest non-NaN ratio in
    the batch (NaN, with a NumPy warning, if no anchor has one).

    :param c_l: indexable sequence of group ids, one per sample
    :param y_l: indexable sequence of labels, one per sample
    :param isMGCL: restrict candidates to samples from a different group
    :return: list of float ratios (empty for an empty batch)
    """
    n = len(c_l)
    if n == 0:
        return []  # np.nanmax would raise on an empty array

    ratios = []
    for i in range(n):
        c_i = c_l[i]
        y_i = y_l[i]
        tn = 0
        fn = 0
        for j in range(n):
            if j == i:
                continue
            if isMGCL and not (c_i != c_l[j]):
                continue  # MGCL only counts samples from a different group
            if y_l[j] == y_i:
                fn += 1
            else:
                tn += 1
        ratios.append(tn / fn if fn != 0 else np.nan)

    # Replace undefined ratios (no false negatives) with the batch maximum.
    R_batch = np.array(ratios)
    R_batch[np.isnan(R_batch)] = np.nanmax(R_batch)
    return R_batch.tolist()
