import torch
# from torchvision import datasets, transforms
from scipy.io import loadmat, savemat
import numpy as np
import random
import torch.nn.functional as F

def load_calib_dataset(args, data_dir='./data'):
    """Build a calibration DataLoader and a zero buffer for its flattened pixels.

    Args:
        args: namespace providing ``dataset`` (one of ``"MNIST"``,
            ``"Fashion_MNIST"``, ``"EMNIST"``) and ``calib_samples`` (batch size).
        data_dir: root directory where the dataset is stored / downloaded.

    Returns:
        ``(dataloader, input_of_sparse_layer)``: a shuffled DataLoader over the
        training split, and a ``np.zeros((784, len(dataset)))`` buffer meant to
        receive one flattened 28x28 sample per column.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.

    NOTE(review): this relies on module-level names ``datasets`` and
    ``transforms`` (torchvision); their import is commented out at the top of
    this file, so calling this raises NameError unless they are provided —
    confirm how the module is meant to get them.
    """
    name = args.dataset
    # Reject unknown names up front instead of failing later with UnboundLocalError.
    if name not in ("MNIST", "Fashion_MNIST", "EMNIST"):
        raise ValueError("Unsupported calibration dataset: {!r}".format(name))

    to_tensor = transforms.Compose([transforms.ToTensor()])
    if name == "MNIST":
        dataset = datasets.MNIST(data_dir, train=True, download=True, transform=to_tensor)
    elif name == "Fashion_MNIST":
        dataset = datasets.FashionMNIST(root=data_dir, train=True, transform=to_tensor, download=True)
    else:
        dataset = datasets.EMNIST(root=data_dir, train=True, transform=to_tensor,
                                  download=True, split='balanced')

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.calib_samples, shuffle=True)
    # One column per training sample (60000 for MNIST/Fashion-MNIST, 112800 for
    # EMNIST-balanced), sized from the dataset rather than hard-coded.
    input_of_sparse_layer = np.zeros((784, len(dataset)))
    return dataloader, input_of_sparse_layer

def rewire_connections(layer):
    """Scatter each column's nonzero mask entries to random rows.

    For every column of ``layer.weight_mask`` the nonzero values are kept
    (same count, same value order) but written to row positions taken from
    ``torch.randperm``. The resulting tensor replaces ``layer.weight_mask``.
    """
    mask = layer.weight_mask
    shuffled = torch.zeros_like(mask).to(layer.device)
    n_rows, n_cols = shuffled.shape[0], shuffled.shape[1]

    for col_idx in range(n_cols):
        col = mask[:, col_idx]
        kept_values = col[col != 0]
        count = col.nonzero().numel()
        target_rows = torch.randperm(n_rows)[:count]
        shuffled[target_rows, col_idx] = kept_values

    layer.weight_mask = shuffled
                


def create_sparse_topological_initialization(args, model, filename=None, eng=None):
    """Initialise the sparse masks of ``model.sparse_layers`` according to ``args``.

    Exactly one strategy runs, checked in this order:

    * ``args.self_correlated_sparse``: build masks from the pixel correlation
      matrix of the calibration dataset. The matrix is cached as
      ``<filename>/corr.mat`` when ``filename`` is given; with
      ``filename=None`` it is computed without caching.
    * ``args.BA``: Barabasi-Albert style masks per layer; the first layer is
      additionally column-rewired when ``args.rewire_first_layer`` is set.
    * ``args.WS``: Watts-Strogatz style masks per layer.

    ``eng`` is accepted for interface compatibility and is unused here.
    """
    if args.self_correlated_sparse:
        import os

        print("Using self correlated sparse of mlp!!!")

        corr_path = os.path.join(filename, "corr.mat") if filename is not None else None
        # Check for the cache *file*: a pre-existing directory without corr.mat
        # used to make loadmat fail.
        if corr_path is not None and os.path.exists(corr_path):
            corr = loadmat(corr_path)["corr"]
        else:
            # Only load (and possibly download) the dataset when the cache misses.
            dataloader, input_of_sparse_layer = load_calib_dataset(args, data_dir='./data')
            for batch_idx, (data, _) in enumerate(dataloader):
                start = batch_idx * args.calib_samples
                input_of_sparse_layer[:, start:start + data.shape[0]] = (
                    data.reshape(-1, 784).numpy().transpose(1, 0)
                )
            corr = np.corrcoef(input_of_sparse_layer)
            print("done")
            if corr_path is not None:
                os.makedirs(filename, exist_ok=True)
                savemat(corr_path, {"corr": corr})
        create_self_correlated_sparse(model, corr, args.dim, args.soft_csti, args.noise_csti)

    elif args.BA:
        for i, layer in enumerate(model.sparse_layers):
            create_ba_sparse(layer)
            if i == 0 and args.rewire_first_layer:
                rewire_connections(layer)

    elif args.WS:
        for layer in model.sparse_layers:
            create_ws_sparse(layer, args)
            
                
def soft_resort(in_neuron_degree):
    """Return a weighted random ordering of the indices of ``in_neuron_degree``.

    Draws ``len(in_neuron_degree)`` indices without replacement via
    ``torch.multinomial``, using the entries as (unnormalised) weights.
    """
    n_items = in_neuron_degree.shape[0]
    return torch.multinomial(in_neuron_degree, num_samples=n_items, replacement=False)


def create_ws_sparse(layer, args):
    """Build a Watts-Strogatz-style bipartite mask and store it on ``layer.weight_mask``.

    A ring-lattice-like adjacency is built between the smaller side (rows) and
    the larger side (columns) of the layer, with per-node degrees of
    floor(K) or floor(K)+1, where K = (1 - sparsity) * in * out / (in + out).
    When ``args.ws_beta != 0`` each edge is removed with probability
    ``ws_beta`` and random empty cells are then filled until the edge count
    reaches ``int((1 - sparsity) * in * out)``.

    The mask is a float tensor of shape ``(layer.indim, layer.outdim)`` on
    ``layer.device``.
    """
    small = min(layer.indim, layer.outdim)
    large = max(layer.indim, layer.outdim)
    mean_degree = (1 - layer.sparsity) * small * large / (small + large)

    lo = int(mean_degree)
    hi = int(mean_degree) + 1
    span = max(large, small)
    degrees = [lo] * int(span * (hi - mean_degree)) + [hi] * int(span * (mean_degree - lo) + 1)
    random.shuffle(degrees)

    adj = np.zeros((small, large))

    # Pass 1: row r links to a run of consecutive (wrapped) columns.
    stride = large / small
    for r in range(small):
        base = int(r * stride)
        adj[r, [(base + k) % large for k in range(degrees[r])]] = 1

    # Pass 2: column c links to a run of consecutive (wrapped) rows, offset by one.
    stride = small / large
    random.shuffle(degrees)
    for c in range(large):
        base = int(c * stride)
        adj[[(base + k + 1) % small for k in range(degrees[c])], c] = 1

    if args.ws_beta != 0:
        # Drop each existing edge with probability ws_beta; np.nonzero walks the
        # edges in row-major order, matching the order of the coin flips.
        coin = np.random.binomial(1, p=args.ws_beta, size=int(np.sum(adj)))
        edge_rows, edge_cols = np.nonzero(adj == 1)
        dropped = coin == 1
        adj[edge_rows[dropped], edge_cols[dropped]] = 0

        # Regrow random edges until the target density is reached.
        missing = int((1 - layer.sparsity) * small * large) - np.sum(adj)
        grown = 0
        while grown < missing:
            r = np.random.randint(0, small)
            c = np.random.randint(0, large)
            if adj[r][c] == 0:
                grown += 1
                adj[r][c] = 1

        print(np.sum(adj), missing)

    mask = torch.Tensor(adj).to(layer.device)
    # adj is (small, large); transpose when the layer's input side is the larger one.
    layer.weight_mask = mask.t() if layer.indim != small else mask

def symmetric_positions(center, D_total, window_size, N_in):
    """Return ``D_total`` input indices spread symmetrically around ``center``.

    ``D_total`` evenly spaced integer positions are taken in
    ``[center - half, center + half]`` with ``half = (window_size - 1) // 2``,
    and wrapped modulo ``N_in``. Positions that collide (rounding, or wrap-around)
    are replaced by the closest *unused* nodes to ``center`` (order center,
    center-1, center+1, ...), so the result holds ``min(D_total, N_in)``
    distinct indices. Only when ``D_total > N_in`` are indices repeated to reach
    the requested length.

    Args:
        center: index the window is centred on.
        D_total: number of indices requested.
        window_size: width of the window the evenly spaced positions come from.
        N_in: size of the input layer (indices wrap modulo this).

    Returns:
        A list of ``D_total`` ints in ``[0, N_in)`` (empty if ``D_total <= 0``).
    """
    half_window = (window_size - 1) // 2
    raw = np.linspace(center - half_window, center + half_window, D_total, dtype=int)

    chosen = []
    seen = set()
    for x in raw:
        pos = int(x) % N_in
        if pos not in seen:
            seen.add(pos)
            chosen.append(pos)

    # Fill the gaps with the nearest unused nodes instead of repeating
    # positions (repeats would silently yield fewer distinct connections).
    anchor = center % N_in
    d = 0
    while len(chosen) < D_total and len(seen) < N_in:
        for cand in ((anchor - d) % N_in, (anchor + d) % N_in):
            if len(chosen) < D_total and cand not in seen:
                seen.add(cand)
                chosen.append(cand)
        d += 1

    # Only reachable when D_total > N_in: every node is used, so repeat them.
    if chosen and len(chosen) < D_total:
        chosen = (chosen * (D_total // len(chosen) + 1))[:D_total]

    return chosen


def get_closest_nodes_centered(i, N, count):
    """Return the first ``count`` indices of the sequence i, i-1, i+1, i-2, i+2, ...

    Neighbours are wrapped modulo ``N``; ``i`` itself is returned as given.
    For ``count <= 0`` the result is empty.
    """
    order = [i]
    step = 1
    while len(order) < count:
        order.extend(((i - step) % N, (i + step) % N))
        step += 1
    return order[:count]

def pick_groups_for_output_node(j, N_in, N_out, D, M, gamma):
    """
    For a given output node with index j, select M groups of contiguous input nodes
    (from an input layer of size N_in) to which this output node will connect.

    The output node connects to D // M * M input nodes, split into M groups of
    size group_size = D // M. The selection is deterministic:

      - gamma = 0  -> Groups are chosen from the most local (adjacent) region.
      - gamma = 1  -> Groups are spread uniformly over the entire input layer.
      - 0 < gamma < 1 -> The window from which groups are selected is linearly
                         interpolated between the minimal (local) window and the full input.

    Output node j is mapped into input space at center = int(j * N_in / N_out);
    both the multi-group window and the single-group block are centred there.

    Args:
      j    : Index of the output node.
      N_in : Total number of nodes in the input layer.
      N_out: Total number of nodes in the output layer.
      D    : Total number of input nodes to connect (remainder of D / M is dropped).
      M    : Number of groups (must be >= 1).
      gamma: Spread parameter (0 = local, 1 = global).

    Returns:
      A list of M groups, where each group is a list of contiguous input node indices.
      (Indices are computed modulo N_in to wrap around.)

    Raises:
      ValueError: if M < 1.
    """
    if M < 1:
        raise ValueError("M (number of groups) must be >= 1, got {}".format(M))

    group_size = D // M
    minimal_window = M * group_size  # Smallest window that holds all groups locally.

    # gamma=0 -> minimal window; gamma=1 -> the entire input layer.
    if N_in > minimal_window:
        window_size = int(round(minimal_window + gamma * (N_in - minimal_window)))
    else:
        window_size = N_in

    # Position of output node j mapped into input-index space.
    center = int(j * (N_in / N_out))

    groups = []
    if M > 1:
        window_start = center - window_size // 2
        # Room available to slide the groups inside the window.
        avail_range = window_size - group_size
        for m in range(M):
            # Spread group starts evenly across the available range.
            offset = int(np.floor(m * avail_range / (M - 1)))
            start = (window_start + offset) % N_in
            groups.append([(start + k) % N_in for k in range(group_size)])
    else:
        # Single group: centre it on the mapped position (not on raw j, which
        # is only a valid input position when N_in == N_out).
        start = (center - group_size // 2) % N_in
        groups.append([(start + k) % N_in for k in range(group_size)])

    return groups

def _dedupe_and_pad(connections, anchor, N_in, D_total):
    """Drop repeated indices and top up to D_total with unused nodes closest to ``anchor``."""
    seen = set()
    unique = []
    for idx in connections:
        idx = int(idx)
        if idx not in seen:
            seen.add(idx)
            unique.append(idx)
    # Closest-first order: anchor, anchor-1, anchor+1, anchor-2, ... (wrapped).
    d = 0
    while len(unique) < D_total and len(seen) < N_in:
        for cand in ((anchor - d) % N_in, (anchor + d) % N_in):
            if len(unique) < D_total and cand not in seen:
                seen.add(cand)
                unique.append(cand)
        d += 1
    return unique[:D_total]


def _expand_from_center(gc, size, N_in):
    """Return up to ``size`` distinct indices in the order gc, gc+1, gc-1, gc+2, gc-2, ... (mod N_in)."""
    size = min(int(size), N_in)  # more than N_in distinct nodes cannot exist
    if size <= 0:
        return []
    first = gc % N_in
    group = [first]
    seen = {first}
    step = 1
    while len(group) < size:
        for pos in ((gc + step) % N_in, (gc - step) % N_in):
            if len(group) < size and pos not in seen:
                seen.add(pos)
                group.append(pos)
        step += 1
    return group


def pick_connections_for_output_node(j, N_in, N_out, sparsity, D_total, M, gamma, args):
    """
    For output node j, returns a list of input node indices that will be connected.

    Output node j is mapped to input position ``j * N_in / N_out``. A window
    around it (size interpolated by ``gamma``, or sampled from an exponential
    distribution when ``args.exp_window`` is set) holds M group centres.
    ``args.synaptic_dist`` selects how the D_total connections are placed:

    * ``"uniform"``: gamma == 0 delegates to ``symmetric_positions``; otherwise
      each group gets a contiguous block (sizes differ by at most one).
    * ``"distance_based"``: connections are assigned to groups by a softmax
      over (negative) circular distance of each group centre, then grown
      outward from the centre.

    Duplicate indices (overlapping groups, wrap-around) are removed and the
    list is padded with the closest unused nodes so it contains
    ``min(D_total, N_in)`` distinct indices (except the gamma == 0 uniform case,
    whose result comes from ``symmetric_positions``).

    ``sparsity`` is accepted for interface compatibility and is unused here.

    Raises:
        ValueError: if ``args.synaptic_dist`` is not a supported value.
    """
    # Position of output node j in input space.
    base_center = j * (N_in / N_out)
    center = int(round(base_center)) % N_in

    # Window size.
    minimal_window = D_total
    if getattr(args, "exp_window", False):
        exp_scale = getattr(args, "exp_window_scale", 400)  # Smaller scale = sharper falloff
        sampled_inc = int(np.random.exponential(scale=exp_scale))
        # Clamp to [minimal_window, N_in]; the lower bound wins if they cross.
        window_size = max(min(minimal_window + sampled_inc, N_in), minimal_window)
    else:
        window_size = (int(round(minimal_window + gamma * (N_in - minimal_window)))
                       if N_in > minimal_window else N_in)

    # Group centres, interpolated between fully local (all at `center`) and
    # uniformly spread over the window.
    window_start = center - window_size // 2
    uniform_centers = (np.linspace(window_start, window_start + window_size, M, endpoint=False)
                       + window_size / (2 * M))
    group_centers = (1 - gamma) * center + gamma * uniform_centers
    group_centers = [int(round(gc)) % N_in for gc in group_centers]

    if args.synaptic_dist == "uniform":
        if gamma == 0:
            return symmetric_positions(center, D_total, window_size, N_in)

        # Split D_total as evenly as possible across the M groups.
        base, remainder = divmod(D_total, M)
        group_sizes = [base + 1 if idx < remainder else base for idx in range(M)]

        connections = []
        for c, g_size in zip(group_centers, group_sizes):
            # Contiguous block of g_size neurons centred at c.
            half = g_size // 2
            connections.extend((c - half + k) % N_in for k in range(g_size))
        # Overlapping blocks would otherwise yield fewer than D_total edges.
        return _dedupe_and_pad(connections, center, N_in, D_total)

    if args.synaptic_dist == "distance_based":
        # Circular distance from the node's input-space position to each group centre.
        distances = []
        for gc in group_centers:
            d_abs = abs(gc - base_center)
            distances.append(min(d_abs, N_in - d_abs))

        # Softmax over negative scaled distances (closer groups are more likely).
        temp = getattr(args, "distance_temp", 0.5)
        epsilon = 1e-10  # To avoid division by zero
        scaled_distances = np.asarray([-d / (temp * N_in + epsilon) for d in distances])
        exp_dist = np.exp(scaled_distances - np.max(scaled_distances))  # Numerical stability
        probs = exp_dist / np.sum(exp_dist)

        # Probabilistically assign each connection to a group.
        group_assignments = np.random.choice(M, size=D_total, p=probs)
        group_sizes = np.bincount(group_assignments, minlength=M)

        connections = []
        for gc, size in zip(group_centers, group_sizes):
            connections.extend(_expand_from_center(gc, size, N_in))

        return _dedupe_and_pad(connections, int(base_center) % N_in, N_in, D_total)

    raise ValueError("Unknown synaptic_dist: {!r}".format(args.synaptic_dist))


def create_dendritic_sparse_scheduler(sparsity, w, args):
    """Build a dendritic-style sparse mask with the shape of ``w``.

    Works on the (smaller, larger) = (N_in, N_out) orientation of ``w``. Each
    output node j receives K1 = floor(D) or K1 + 1 input connections, where
    D = N_in * (1 - sparsity), mixed so the mean degree is D. The connections
    themselves come from ``pick_connections_for_output_node`` with a
    per-node spread ``gamma_j``:

    * ``args.exp_window`` set: gamma = 1 for every node.
    * otherwise ``args.gamma`` (default 0.5) and ``args.gamma_dist``
      (``"fixed"`` default, ``"gaussian"`` with ``args.gamma_std``, or
      ``"uniform"``).

    When ``args.ws_beta != 0`` each edge is removed with probability
    ``ws_beta`` and the same number of random empty cells is filled.

    Returns:
        ``torch.LongTensor`` 0/1 mask on ``w.device`` with ``w``'s shape.

    Raises:
        ValueError: for an unknown ``gamma_dist``.
    """
    N_in = min(w.shape[0], w.shape[1])
    N_out = max(w.shape[0], w.shape[1])
    M = args.M

    # Target mean in-degree and the two integer degrees around it.
    D_float = N_in * (1 - sparsity)
    K1 = int(D_float)
    K2 = K1 + 1

    # n1*K1 + n2*K2 = N_out*D_float with n1 + n2 = N_out  =>  n1 = N_out*(K2 - D_float).
    count_K1 = int(round(N_out * (K2 - D_float)))
    count_K1 = min(max(count_K1, 0), N_out)
    count_K2 = N_out - count_K1

    # Randomly assign which output nodes get K1 vs K2 connections.
    connection_counts = [K1] * count_K1 + [K2] * count_K2
    random.shuffle(connection_counts)

    # Adjacency matrix (rows: input nodes, columns: output nodes).
    adj = np.zeros((N_in, N_out), dtype=int)

    if getattr(args, "exp_window", False):
        gamma = 1
        gamma_dist = "fixed"
    else:
        gamma = args.gamma if args.gamma is not None else 0.5
        gamma_dist = args.gamma_dist if args.gamma_dist is not None else "fixed"

    for j in range(N_out):
        if gamma_dist == "fixed":
            gamma_j = gamma
        elif gamma_dist == "gaussian":
            gamma_std = getattr(args, "gamma_std", 0.1)
            gamma_j = np.clip(np.random.normal(gamma, gamma_std), 0, 1)
        elif gamma_dist == "uniform":
            gamma_j = np.random.uniform(0, 1)
        else:
            raise ValueError("Unknown gamma distribution: {}".format(gamma_dist))

        # Window size (including any exp_window sampling) is determined inside
        # pick_connections_for_output_node.
        connections = pick_connections_for_output_node(
            j, N_in, N_out, sparsity, connection_counts[j], M, gamma_j, args
        )
        for i in connections:
            adj[i, j] = 1

    # --- REWIRING ---
    if args.ws_beta != 0:
        total_edges = int(np.sum(adj))
        randomness = np.random.binomial(1, p=args.ws_beta, size=total_edges)
        # np.nonzero walks edges in row-major order, matching the coin flips.
        edge_rows, edge_cols = np.nonzero(adj == 1)
        drop = randomness == 1
        adj[edge_rows[drop], edge_cols[drop]] = 0

        removed_edges = total_edges - int(np.sum(adj))
        nrAdd = 0
        while nrAdd < removed_edges:
            i_rand = np.random.randint(0, N_in)
            j_rand = np.random.randint(0, N_out)
            if adj[i_rand, j_rand] == 0:
                adj[i_rand, j_rand] = 1
                nrAdd += 1
        print("After rewiring, total edges:", np.sum(adj), "removed:", removed_edges)

    if w.shape[0] != N_in:
        return torch.LongTensor(adj).to(w.device).t()
    return torch.LongTensor(adj).to(w.device)




def create_ws_sparse_scheduler(sparsity, w, args):
    """Return a Watts-Strogatz-style 0/1 mask with the shape of ``w``.

    A ring-lattice-like adjacency links the smaller side of ``w`` (rows) to the
    larger side (columns) with per-node degrees floor(K) or floor(K)+1, where
    K = (1 - sparsity) * in * out / (in + out). When ``args.ws_beta != 0``
    each edge is dropped with probability ``ws_beta`` and random empty cells
    are filled until ``int((1 - sparsity) * in * out)`` edges exist.

    Returns:
        ``torch.LongTensor`` on ``w.device``, transposed when needed so its
        shape matches ``w``.
    """
    small = min(w.shape[0], w.shape[1])
    large = max(w.shape[0], w.shape[1])
    mean_degree = (1 - sparsity) * small * large / (small + large)

    lo = int(mean_degree)
    hi = int(mean_degree) + 1
    span = max(large, small)
    degrees = [lo] * int(span * (hi - mean_degree)) + [hi] * int(span * (mean_degree - lo) + 1)
    random.shuffle(degrees)

    adj = np.zeros((small, large))

    # Pass 1: row r links to a run of consecutive (wrapped) columns.
    stride = large / small
    for r in range(small):
        base = int(r * stride)
        adj[r, [(base + k) % large for k in range(degrees[r])]] = 1

    # Pass 2: column c links to a run of consecutive (wrapped) rows, offset by one.
    stride = small / large
    random.shuffle(degrees)
    for c in range(large):
        base = int(c * stride)
        adj[[(base + k + 1) % small for k in range(degrees[c])], c] = 1

    if args.ws_beta != 0:
        # Drop each edge with probability ws_beta (row-major edge order).
        coin = np.random.binomial(1, p=args.ws_beta, size=int(np.sum(adj)))
        edge_rows, edge_cols = np.nonzero(adj == 1)
        dropped = coin == 1
        adj[edge_rows[dropped], edge_cols[dropped]] = 0

        # Regrow random edges up to the target density.
        missing = int((1 - sparsity) * small * large) - np.sum(adj)
        grown = 0
        while grown < missing:
            r = np.random.randint(0, small)
            c = np.random.randint(0, large)
            if adj[r][c] == 0:
                grown += 1
                adj[r][c] = 1

        print(np.sum(adj), missing)

    mask = torch.LongTensor(adj).to(w.device)
    return mask.t() if w.shape[0] != small else mask




def generate_barabasi_alberta_graph(N, m):
    """Generate a symmetric Barabasi-Albert style adjacency matrix on N nodes.

    The graph starts from a fully connected seed of m2 = int(m) + 1 nodes.
    Every later node i attaches to ``m1 = int(m)`` or ``m2`` distinct earlier
    nodes, chosen with probability proportional to their current degree.
    The m1/m2 counts are mixed so the mean matches a fractional ``m``.

    Previously an integer ``m`` skipped generation entirely and returned an
    all-zero matrix; integer and fractional ``m`` now share the same path.

    Returns:
        ``np.ndarray`` of shape (N, N) with 0/1 float entries, symmetric,
        zero diagonal.
    """
    adj = np.zeros((N, N))
    if not isinstance(m, int):
        print("m is not an integer")
    m1 = int(m)
    m2 = int(m) + 1

    # Fully connected seed clique on the first m2 nodes.
    seed = np.triu(np.ones((m2, m2)), k=1)
    adj[:m2, :m2] = seed + seed.T

    # Per-node attachment counts: mix of m1 and m2 averaging to m.
    my_list = [m1] * int((N - m2) * (m2 - m)) + [m2] * int((N - m2) * (m - m1) + 1)
    random.shuffle(my_list)

    for i in range(m2, N):
        targets = np.arange(i)
        degrees = np.sum(adj[:i, :i], axis=1)
        p_normalized = degrees / np.sum(degrees)  # preferential attachment

        m_tmp = my_list[i - m2]
        idx = np.random.choice(targets, size=m_tmp, replace=False, p=p_normalized)
        adj[i, idx] = 1
        adj[idx, i] = 1

    return adj

def create_ba_sparse(layer):
    """Build a Barabasi-Albert-derived bipartite mask and store it on ``layer.weight_mask``.

    A BA graph on ``layer.indim + layer.outdim`` nodes is generated. Its nodes
    are split at random into an input group and an output group, and the
    input->output edges become the mask. Edges that fall *inside* a group
    cannot exist in a bipartite layer. Their endpoints are re-paired across
    the groups, and part of the leftover endpoints are attached
    preferentially (by current degree).

    The stored mask is a float tensor of shape ``(layer.indim, layer.outdim)``
    on ``layer.device``.

    NOTE(review): the ``while True`` retry loops below never terminate if
    every node with nonzero sampling probability is already connected to the
    chosen endpoint — confirm this cannot happen for the sparsities used.
    """
    # m * (indim + outdim) == (1 - sparsity) * indim * outdim, i.e. the
    # average-degree parameter is chosen from the layer's kept-weight budget.
    m = (1- layer.sparsity) * layer.indim * layer.outdim / (layer.indim + layer.outdim)
    adj = generate_barabasi_alberta_graph(layer.indim + layer.outdim, m)
    # Randomly split the graph's nodes into an input group and an output group.
    nodes = list(range(layer.indim + layer.outdim))
    random.shuffle(nodes)
    
    # list(set(...)) discards the shuffled order inside each group; the order
    # follows set iteration order.
    layer_N = list(set(nodes[:layer.indim]))
    layer_M = list(set(nodes[layer.indim:]))

    # Permute rows and columns so the input group occupies the first indim indices.
    adj = adj[layer_N+layer_M]
    adj = adj[:, layer_N+layer_M]
    
    # Keep each undirected edge once (strict upper triangle).
    adj = np.triu(adj, k=1)
    # Cross-group block = candidate bipartite mask, shape (indim, outdim).
    # Basic slicing makes this a view, so writes below also land in `adj`.
    final_adj = adj[:layer.indim, layer.indim:]
    # Endpoints of intra-group edges (row indices followed by column indices),
    # local to each group: layer1 in [0, indim), layer2 in [0, outdim).
    layer1_fru = np.array(adj[:layer.indim, :layer.indim].nonzero()).reshape(-1)
    layer2_fru = np.array(adj[layer.indim:, layer.indim:].nonzero()).reshape(-1)
    np.random.shuffle(layer1_fru)
    np.random.shuffle(layer2_fru)

    if len(layer1_fru) <= len(layer2_fru):
        # print("in")
        # Pair every input-group endpoint with the first unused output-group
        # endpoint it is not already connected to.
        layer2_fru_flag = np.zeros_like(layer2_fru)
        for i in range(len(layer1_fru)):
            for j in range(len(layer2_fru)):
                if layer2_fru_flag[j] == 0:
                    if final_adj[layer1_fru[i], layer2_fru[j]]:
                        continue
                    else:
                        final_adj[layer1_fru[i], layer2_fru[j]] = 1
                        layer2_fru_flag[j] = 1
                        break
                else:
                    continue

        # Unpaired output-group endpoints: the first half of them each get one
        # new edge to an input node sampled by current row degree.
        zero_indices = np.where(layer2_fru_flag == 0)[0]
        layer2_r = layer2_fru[zero_indices]
        for i in range(len(layer2_r)//2):
            node_degrees = np.sum(final_adj, axis=1)
            prob_distribution = node_degrees / np.sum(node_degrees)
            while True:
                # size=1 -> target_node is a length-1 array.
                target_node = np.random.choice(np.arange(layer.indim), size=1, replace=False, p=prob_distribution)
                if final_adj[target_node, layer2_r[i]] == 1:
                    continue
                else:
                    final_adj[target_node, layer2_r[i]] = 1
                    break
                
    elif len(layer1_fru) > len(layer2_fru):
        # Mirror image: pair every output-group endpoint with the first unused
        # input-group endpoint it is not already connected to.
        layer1_fru_flag = np.zeros_like(layer1_fru)
        for i in range(len(layer2_fru)):
            for j in range(len(layer1_fru)):
                if layer1_fru_flag[j] == 0:
                    if final_adj[layer1_fru[j], layer2_fru[i]]:
                        continue
                    else:
                        final_adj[layer1_fru[j], layer2_fru[i]] = 1
                        layer1_fru_flag[j] = 1
                        break
                else:
                    continue
        # Unpaired input-group endpoints: the first half of them each get one
        # new edge to an output node sampled by current column degree.
        zero_indices = np.where(layer1_fru_flag == 0)[0]
        layer1_r = layer1_fru[zero_indices]
        for i in range(len(layer1_r)//2):
            node_degrees = np.sum(final_adj, axis=0)
            prob_distribution = node_degrees / np.sum(node_degrees)
            while True:
                target_node = np.random.choice(np.arange(layer.outdim), size=1, replace=False, p=prob_distribution)
                if final_adj[layer1_r[i], target_node] == 1:
                    continue
                else:
                    final_adj[layer1_r[i], target_node] = 1
                    break
    # print(int(np.sum(final_adj)))
    layer.weight_mask = torch.Tensor(final_adj).to(layer.device)
    
def create_self_correlated_sparse(model, corr, dim, soft=False, noise=False):
    """Set each sparse layer's mask from an input correlation matrix.

    A cleaned copy of ``corr`` is made, with NaNs set to 0, the diagonal zeroed
    and Gaussian noise added when ``noise`` is true. Each layer in
    ``model.sparse_layers`` then keeps its ``n_params`` strongest entries
    (or samples them when ``soft``) via ``update_topology``.

    * ``dim == 1``: every layer uses the cleaned matrix as-is.
    * ``dim == 2``: the matrix is tiled 2x2; the first layer uses the top half
      (rows ``[:n]``, all ``2n`` columns), later layers the full tiled matrix.
    * any other ``dim``: nothing is changed.

    The caller's ``corr`` array is not modified (it previously was, in place).
    """
    corr = np.array(corr, dtype=float)  # private copy
    corr[np.isnan(corr)] = 0
    np.fill_diagonal(corr, 0)

    if noise:
        corr += np.random.randn(corr.shape[0], corr.shape[1])

    if dim == 1:
        for layer in model.sparse_layers:
            update_topology(layer, corr, layer.n_params, soft)

    elif dim == 2:
        # [[corr, corr], [corr, corr]]
        expanded = np.tile(corr, (2, 2))
        half = expanded.shape[0] // 2
        for i, layer in enumerate(model.sparse_layers):
            if i == 0:
                update_topology(layer, expanded[:half, :].copy(), layer.n_params, soft)
            else:
                update_topology(layer, expanded, layer.n_params, soft)
    
def update_topology(layer, corr, number_of_links, soft=False):
    """Set ``layer.weight_mask`` to a 0/1 mask selecting entries of ``corr``.

    Args:
        layer: object with a ``device`` attribute; receives ``weight_mask``.
        corr: 2-D array-like of scores; their absolute values are used.
        number_of_links: how many entries to enable.
        soft: if true, sample ``number_of_links`` entries without replacement
            with probability proportional to ``|corr|``; otherwise enable every
            entry whose ``|corr|`` is at least the ``number_of_links``-th
            largest value (ties can enable a few more).

    With ``number_of_links <= 0`` the mask is all zeros. Previously the hard
    path then indexed ``[-1]`` and enabled every entry.
    """
    scores = torch.abs(torch.Tensor(corr))
    adj = torch.zeros_like(scores)

    if number_of_links > 0:
        flat = scores.flatten()
        if soft:
            probabilities = flat / flat.sum()
            sampled_flat_indices = torch.multinomial(probabilities, number_of_links, replacement=False)
            adj = adj.reshape(-1)
            adj[sampled_flat_indices] = 1
            adj = adj.reshape(scores.shape[0], scores.shape[1])
        else:
            # k-th largest |corr| is the cut-off.
            threshold = torch.topk(flat, number_of_links).values[-1]
            adj[scores >= threshold] = 1

    layer.weight_mask = adj.to(layer.device)


def create_self_correlated_scheduler(model, corr, dim):
    """Set each sparse layer's mask from a cleaned copy of ``corr`` (hard top-k).

    NaNs in ``corr`` are set to 0 and its diagonal is zeroed in a private copy.
    The caller's array is not modified (it previously was, in place). Each
    layer in ``model.sparse_layers`` keeps its ``n_params`` strongest entries
    via ``update_topology``.

    * ``dim == 1``: every layer uses the cleaned matrix.
    * ``dim == 2``: the matrix is tiled 2x2; the first layer uses the top half
      (rows ``[:n]``, all ``2n`` columns), later layers the full tiled matrix.
    * any other ``dim``: nothing is changed.
    """
    corr = np.array(corr, dtype=float)  # private copy
    corr[np.isnan(corr)] = 0
    np.fill_diagonal(corr, 0)

    if dim == 1:
        for layer in model.sparse_layers:
            update_topology(layer, corr, layer.n_params)

    elif dim == 2:
        # [[corr, corr], [corr, corr]]
        expanded = np.tile(corr, (2, 2))
        half = expanded.shape[0] // 2
        for i, layer in enumerate(model.sparse_layers):
            target = expanded[:half, :].copy() if i == 0 else expanded
            update_topology(layer, target, layer.n_params)

def update_topology_scheduler(w, corr, number_of_links):
    """Return a 0/1 mask enabling the ``number_of_links`` largest ``|corr|`` entries.

    The cut-off is the ``number_of_links``-th largest absolute value, and ties
    at the cut-off are all enabled. Previously the (k+1)-th value was used,
    which enabled one extra link and raised IndexError for k == corr.size.
    With ``number_of_links <= 0`` the mask is all zeros.

    Returns:
        float tensor shaped like ``corr`` on ``w.device``.
    """
    scores = torch.abs(torch.Tensor(corr))
    adj = torch.zeros_like(scores)
    if number_of_links > 0:
        threshold = torch.topk(scores.flatten(), number_of_links).values[-1]
        adj[scores >= threshold] = 1
    return adj.to(w.device)

def update_topology_scheduler_soft(w, corr, number_of_links):
    """Sample ``number_of_links`` entries of ``corr`` (weighted by ``|corr|``) into a 0/1 mask.

    Sampling is without replacement via ``torch.multinomial``. The result has
    the shape of ``corr`` and lives on ``w.device``.
    """
    weights = torch.abs(torch.Tensor(corr))
    flat_weights = weights.flatten()
    picked = torch.multinomial(flat_weights / flat_weights.sum(), number_of_links, replacement=False)
    mask = torch.zeros_like(weights).flatten()
    mask[picked] = 1
    return mask.view(weights.shape).to(w.device)