import torch
# from torchvision import datasets, transforms
from scipy.io import loadmat, savemat
import numpy as np
import random
from scipy import sparse
import torch.nn.functional as F

def load_calib_dataset(args, data_dir='./data'):
    """Build a shuffled training-set loader used to calibrate input correlations.

    Args:
        args: namespace providing ``dataset`` (one of "MNIST", "Fashion_MNIST",
            "EMNIST") and ``calib_samples`` (loader batch size).
        data_dir: root directory passed to the torchvision dataset (downloaded
            there if missing).

    Returns:
        ``(dataloader, input_of_sparse_layer)`` where ``input_of_sparse_layer``
        is a zero-filled ``(784, len(dataset))`` float array that callers fill
        with flattened images, one column per sample.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
    """
    supported = ("MNIST", "Fashion_MNIST", "EMNIST")
    if args.dataset not in supported:
        raise ValueError(f"Unsupported calibration dataset: {args.dataset!r}")

    # torchvision is imported lazily: the module-level import is disabled, and
    # only this function needs it.
    from torchvision import datasets, transforms

    to_tensor = transforms.Compose([transforms.ToTensor()])
    if args.dataset == "MNIST":
        dataset = datasets.MNIST(data_dir, train=True, download=True, transform=to_tensor)
    elif args.dataset == "Fashion_MNIST":
        dataset = datasets.FashionMNIST(
            root=data_dir, train=True, transform=to_tensor, download=True
        )
    else:  # "EMNIST"
        dataset = datasets.EMNIST(
            root=data_dir, train=True, transform=to_tensor, download=True, split='balanced'
        )

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.calib_samples, shuffle=True
    )
    # One column per training image (60000 for (Fashion-)MNIST, 112800 for
    # EMNIST-balanced), sized from the dataset rather than hard-coded.
    input_of_sparse_layer = np.zeros((784, len(dataset)))
    return dataloader, input_of_sparse_layer

def rewire_connections(layer):
    """Scatter each column's nonzero mask entries onto random rows.

    For every column of ``layer.weight_mask`` the number of nonzero entries
    is kept. Those entries are written, in their original top-to-bottom
    order, into the first rows of a fresh ``torch.randperm`` permutation.
    The rebuilt mask (created on ``layer.device``) replaces
    ``layer.weight_mask``.
    """
    mask = layer.weight_mask
    n_rows, n_cols = mask.shape[0], mask.shape[1]
    rewired = torch.zeros_like(mask).to(layer.device)

    for col_idx in range(n_cols):
        col = mask[:, col_idx]
        kept_values = col[col != 0]
        target_rows = torch.randperm(n_rows)[:kept_values.numel()]
        rewired[target_rows, col_idx] = kept_values

    layer.weight_mask = rewired
                


def create_sparse_topological_initialization(args, model, filename=None, eng=None):
    """Initialise the sparse-layer masks of *model* with the topology chosen in *args*.

    Exactly one branch runs. They are checked in this order:

    * ``args.self_correlated_sparse``: compute the correlation matrix of the
      calibration inputs (784 values per image) and pass it to
      ``create_self_correlated_sparse`` with ``args.dim``, ``args.soft_csti``
      and ``args.noise_csti``. The matrix is cached as ``<filename>/corr.mat``
      and reused when that file exists.
    * ``args.BA``: call ``create_ba_sparse`` on every layer in
      ``model.sparse_layers``. When ``args.rewire_first_layer`` is set, the
      first layer is also passed to ``rewire_connections``.
    * ``args.WS``: call ``create_ws_sparse`` on every sparse layer.

    If no flag is set the model is left untouched.

    Args:
        args: configuration namespace with the flags above.
        model: object exposing ``sparse_layers`` (BA/WS branches).
        filename: cache directory for ``corr.mat``. When None the matrix is
            computed but not cached.
        eng: not used by this function.
    """
    if args.self_correlated_sparse:
        import os

        print("Using self correlated sparse of mlp!!!")
        cache_file = os.path.join(filename, "corr.mat") if filename is not None else None

        if cache_file is not None and os.path.exists(cache_file):
            corr = loadmat(cache_file)["corr"]
        else:
            # Only load (and possibly download) the dataset on a cache miss.
            dataloader, input_of_sparse_layer = load_calib_dataset(args, data_dir='./data')
            offset = 0
            for data, _ in dataloader:
                # (batch, ...) images -> (784, batch) columns
                batch = data.reshape(-1, 784).numpy().transpose(1, 0)
                input_of_sparse_layer[:, offset:offset + batch.shape[1]] = batch
                offset += batch.shape[1]
            # np.corrcoef treats each row (input pixel) as one variable.
            corr = np.corrcoef(input_of_sparse_layer)
            print("done")
            if cache_file is not None:
                # exist_ok: the directory may exist without a cached corr.mat.
                os.makedirs(filename, exist_ok=True)
                savemat(cache_file, {"corr": corr})
        create_self_correlated_sparse(model, corr, args.dim, args.soft_csti, args.noise_csti)

    elif args.BA:
        for i, layer in enumerate(model.sparse_layers):
            create_ba_sparse(layer)
            if i == 0 and args.rewire_first_layer:
                rewire_connections(layer)

    elif args.WS:
        for layer in model.sparse_layers:
            create_ws_sparse(layer, args)
            
                
def soft_resort(in_neuron_degree):
    """Draw a random ordering of node indices, biased by their weights.

    Every index of ``in_neuron_degree`` is drawn exactly once (sampling
    without replacement via ``torch.multinomial``). Higher-weight indices
    tend to appear earlier.
    """
    n_nodes = in_neuron_degree.shape[0]
    return torch.multinomial(in_neuron_degree, num_samples=n_nodes, replacement=False)


def create_ws_sparse(layer, args):
    """Give *layer* a Watts-Strogatz-style mask: ring lattice plus random rewiring.

    With ``indim = min(shape)`` and ``outdim = max(shape)`` of the layer's
    mask, column ``i`` connects to a contiguous, wrapping run of rows centred
    near ``i * indim / outdim``. Column degrees are floor(K) or ceil(K) with
    ``K = (1 - args.sparsity) * indim``, mixed so they total about
    ``K * outdim``.

    When ``args.random_rewiring`` is non-zero, each edge is dropped with that
    probability. Edges are then regrown at uniformly random empty positions
    until the mask holds ``int((1 - args.sparsity) * indim * outdim)`` ones.

    The new float mask on ``layer.device`` replaces ``layer.weight_mask``.
    It is transposed when ``layer.indim`` is not the smaller dimension.
    """
    # Dimensions come from the layer's current mask.
    # NOTE(review): assumes layer.weight_mask already exists with its final
    # shape when this is called; confirm against the layer class.
    w = layer.weight_mask
    indim = min(w.shape[0], w.shape[1])
    outdim = max(w.shape[0], w.shape[1])

    K = (1 - args.sparsity) * indim
    K1 = int(K)
    K2 = K1 + 1
    # Exactly outdim degrees: n_K2 columns get ceil(K), the rest floor(K).
    n_K2 = int(round(outdim * (K - K1)))
    my_list = [K1] * (outdim - n_K2) + [K2] * n_K2
    random.shuffle(my_list)

    adj = np.zeros((indim, outdim))
    rate = indim / outdim
    for i in range(outdim):
        idx = [(int(i * rate - my_list[i] / 2) + j) % indim for j in range(my_list[i])]
        adj[idx, i] = 1

    # rewiring
    if args.random_rewiring != 0:
        # np.nonzero lists edges in row-major order, one Bernoulli draw per edge.
        rows, cols = np.nonzero(adj)
        randomness = np.random.binomial(1, p=args.random_rewiring, size=rows.size)
        dropped = randomness == 1
        adj[rows[dropped], cols[dropped]] = 0

        # regrow at uniformly random empty positions up to the target density
        noRewires = int((1 - args.sparsity) * indim * outdim) - np.sum(adj)
        nrAdd = 0
        while nrAdd < noRewires:
            i = np.random.randint(0, indim)
            j = np.random.randint(0, outdim)
            if adj[i][j] == 0:
                nrAdd += 1
                adj[i][j] = 1

        print(np.sum(adj), noRewires)

    if layer.indim != indim:
        layer.weight_mask = torch.Tensor(adj).to(layer.device).t()
    else:
        layer.weight_mask = torch.Tensor(adj).to(layer.device)


def _adjust_samples_probabilistic(samples, target_total):
    """Convert float samples to integers using probabilistic rounding.

    Each sample is floored. The shortfall ``target_total - sum(floors)`` is
    then handed out one unit at a time, to indices drawn with probability
    proportional to their fractional parts, or uniformly when every fractional
    part is 0. If the floors already reach or exceed ``target_total``, they
    are returned unchanged.

    This function was previously also named ``_adjust_samples`` and was
    shadowed by the later definition of that name, which is the one the
    degree scheduler calls.

    Returns:
        list of int, same length as ``samples``.
    """
    samples = np.asarray(samples, dtype=float)
    integer_parts = np.floor(samples).astype(int)
    fractional = samples - integer_parts

    remainder = int(target_total - np.sum(integer_parts))
    if remainder > 0 and integer_parts.size > 0:
        frac_sum = np.sum(fractional)
        if frac_sum > 0:
            probs = fractional / frac_sum
        else:
            # No fractional mass to weight by: spread the extra units uniformly.
            probs = np.full(integer_parts.size, 1.0 / integer_parts.size)
        extra_indices = np.random.choice(integer_parts.size, size=remainder, p=probs, replace=True)
        np.add.at(integer_parts, extra_indices, 1)

    return integer_parts.tolist()


def symmetric_positions(center, D_total, window_size, N_in):
    """Return ``D_total`` distinct input indices spread evenly around ``center``.

    Points are spaced evenly over ``center ± (window_size - 1) // 2`` and
    wrapped modulo ``N_in``. When integer flooring makes neighbouring points
    collide, the shortfall is filled with the unused nodes closest to
    ``center`` (center, center-1, center+1, ...). At most ``N_in`` indices
    can be returned, since they are all distinct.

    Returns:
        list of int in ``[0, N_in)``, without duplicates.
    """
    target = min(int(D_total), N_in)
    if target <= 0:
        return []

    half_window = (window_size - 1) // 2
    # floor (not truncation toward zero) so positions left of 0 wrap
    # consistently instead of colliding with their positive mirror images.
    raw = np.floor(np.linspace(center - half_window, center + half_window, target)).astype(int)

    positions = []
    seen = set()
    for x in raw:
        p = int(x) % N_in
        if p not in seen:
            seen.add(p)
            positions.append(p)

    # Top up with the nearest unused nodes so every index is distinct.
    d = 0
    while len(positions) < target:
        for off in ((0,) if d == 0 else (-d, d)):
            p = (center + off) % N_in
            if p not in seen and len(positions) < target:
                seen.add(p)
                positions.append(p)
        d += 1

    return positions


def get_closest_nodes_centered(i, N, count):
    """List ``count`` nodes around node ``i`` on a ring of ``N`` nodes.

    Nodes come in order of increasing ring offset, left before right:
    ``i, i-1, i+1, i-2, i+2, ...`` (neighbours wrap modulo ``N``; ``i``
    itself is emitted as given). Once ``count`` exceeds ``N`` the sequence
    wraps around and repeats nodes.
    """
    order = [i]
    offset = 1
    while len(order) < count:
        for neighbour in ((i - offset) % N, (i + offset) % N):
            if len(order) < count:
                order.append(neighbour)
        offset += 1
    return order[:count]


def _group_block(center, size, N_in):
    """Return ``size`` consecutive input indices starting ``size // 2`` before ``center`` (mod N_in)."""
    half = size // 2
    return [(center - half + k) % N_in for k in range(size)]


def _expand_alternating(center, size, N_in, first_sign):
    """Return up to ``size`` distinct indices grown outward from ``center``.

    Visiting order is center, center+s, center-s, center+2s, center-2s, ...
    with ``s = first_sign``, wrapping modulo ``N_in``. ``size`` is capped at
    ``N_in`` so the expansion always terminates.
    """
    size = min(int(size), N_in)
    group, seen = [], set()
    step = 0
    while len(group) < size:
        offsets = (0,) if step == 0 else (first_sign * step, -first_sign * step)
        for off in offsets:
            pos = (center + off) % N_in
            if pos not in seen and len(group) < size:
                seen.add(pos)
                group.append(pos)
        step += 1
    return group


def _rebalance_round_robin(group_sizes, D_total):
    """Shift group sizes by one unit at a time, cycling over groups, until they sum to ``D_total``.

    Groups are shrunk only while larger than 1. If that cannot close the gap
    (more non-empty groups than ``D_total``), groups may be emptied rather
    than looping forever.
    """
    group_sizes = np.array(group_sizes, dtype=int)
    diff = int(D_total - np.sum(group_sizes))
    min_size = 1
    while diff != 0:
        progressed = False
        for i in range(len(group_sizes)):
            if diff == 0:
                break
            if diff > 0:
                group_sizes[i] += 1
                diff -= 1
                progressed = True
            elif group_sizes[i] > min_size:
                group_sizes[i] -= 1
                diff += 1
                progressed = True
        if not progressed:
            if min_size == 0:
                break
            min_size = 0
    return group_sizes


def _dedupe_and_pad_from_window(connections, D_total, window_start, window_size, N_in):
    """Drop duplicate indices, then top up to ``D_total`` with random unused nodes.

    New nodes are sampled from the node's window first, or from the whole
    layer when the window has too few free nodes.
    """
    unique_connections = list(set(connections))
    if len(unique_connections) < D_total:
        possible_window = [(window_start + k) % N_in for k in range(window_size)]
        available = list(set(possible_window) - set(unique_connections))
        needed = D_total - len(unique_connections)
        if len(available) < needed:
            available = list(set(range(N_in)) - set(unique_connections))
        additional = random.sample(available, min(needed, len(available)))
        unique_connections.extend(additional)
    return unique_connections[:D_total]


def _dedupe_and_pad_closest(connections, D_total, center, N_in):
    """Drop duplicate indices, then top up to ``D_total`` with the unused nodes nearest ``center``.

    Candidates are tried in the order center, center-1, center+1, ...,
    skipping any node that is already chosen.
    """
    unique_connections = list(set(connections))
    target = min(D_total, N_in)
    if len(unique_connections) < target:
        used = set(unique_connections)
        d = 0
        while len(unique_connections) < target and d <= N_in:
            for off in ((0,) if d == 0 else (-d, d)):
                node = (center + off) % N_in
                if node not in used and len(unique_connections) < target:
                    used.add(node)
                    unique_connections.append(node)
            d += 1
    return unique_connections[:D_total]


def pick_connections_for_output_node(j, N_in, N_out, sparsity, D_total, M, gamma, args):
    """Return the input indices that output node ``j`` connects to.

    The node's centre is ``j * N_in / N_out``. Its window has width
    ``D_total + gamma * (N_in - D_total)``, clipped to ``[D_total, N_in]``.
    ``M`` group ("dendrite") centres are spread evenly over the window and
    pulled toward the node's centre by ``1 - gamma``. ``args.synaptic_dist``
    decides how the ``D_total`` connections are split across groups:

    * "fixed": equal split (gamma == 0 uses ``symmetric_positions`` instead).
    * "uniform": sizes from uniform draws (``args.uniform_spread``, default 2).
    * "gaussian": sizes from normal draws (``args.gaussian_std``).
    * "spatial_gaussian" / "spatial_inversegaussian": groups near / far from
      the node's centre get more connections (``args.synaptic_std`` /
      ``args.gaussian_sigma``).

    Duplicate indices are removed and the list is topped up, so the result
    holds ``D_total`` distinct indices whenever ``D_total <= N_in``.
    ``sparsity`` is accepted for interface compatibility but not used here.

    Raises:
        ValueError: for an unknown ``args.synaptic_dist``.
    """
    base_center = j * (N_in / N_out)
    center = int(round(base_center)) % N_in

    base_window = D_total
    window_size = int(round(base_window + gamma * (N_in - base_window)))
    window_size = max(D_total, min(window_size, N_in))

    window_start = center - window_size // 2
    uniform_centers = np.linspace(window_start, window_start + window_size, M, endpoint=False) + window_size / (2 * M)
    group_centers = (1 - gamma) * center + gamma * uniform_centers
    group_centers = [int(round(gc)) % N_in for gc in group_centers]

    synaptic_dist = args.synaptic_dist

    if synaptic_dist == "fixed":
        if gamma == 0:
            return symmetric_positions(center, D_total, window_size, N_in)
        # Partition D_total over M groups; a random subset gets one extra.
        base = D_total // M
        remainder = D_total % M
        group_sizes = [base] * M
        for idx in random.sample(range(M), remainder):
            group_sizes[idx] += 1
        connections = []
        for gc, g_size in zip(group_centers, group_sizes):
            connections.extend(_group_block(gc, g_size, N_in))
        return _dedupe_and_pad_from_window(connections, D_total, window_start, window_size, N_in)

    if synaptic_dist == "uniform":
        spread = getattr(args, "uniform_spread", 2)
        mean_connections = D_total / M
        low = mean_connections * (1 - spread)
        high = mean_connections * (1 + spread)
        random_values = np.random.uniform(low, high, M)
        # With spread > 1 draws can be negative, which would yield negative
        # group sizes; treat them as empty groups.
        random_values = np.clip(random_values, 0, None)
        total = np.sum(random_values)
        if total <= 0:
            random_values = np.ones(M)
            total = float(M)
        group_sizes = np.round(random_values / total * D_total).astype(int)
        group_sizes = _rebalance_round_robin(group_sizes, D_total)
        connections = []
        for gc, g_size in zip(group_centers, group_sizes):
            connections.extend(_group_block(gc, g_size, N_in))
        return _dedupe_and_pad_from_window(connections, D_total, window_start, window_size, N_in)

    if synaptic_dist == "gaussian":
        mean_conn = D_total / M  # target mean connections per dendrite
        std_conn = getattr(args, "gaussian_std", mean_conn)  # default: 100% of the mean
        group_sizes = np.random.normal(loc=mean_conn, scale=std_conn, size=M)
        # At least one connection per group, rounded to integers.
        group_sizes = np.clip(group_sizes, 1, None).round().astype(int)

        diff = D_total - np.sum(group_sizes)
        while diff != 0:
            # Groups that may grow (diff > 0) or shrink (diff < 0).
            adjustable = np.where(
                (group_sizes < (2 * mean_conn)) if diff > 0 else (group_sizes > 1)
            )[0]
            if len(adjustable) == 0:
                break  # any leftover mismatch is fixed by the dedupe/pad step
            idx = np.random.choice(adjustable)
            if diff > 0:
                group_sizes[idx] += 1
                diff -= 1
            else:
                group_sizes[idx] -= 1
                diff += 1

        connections = []
        for gc, g_size in zip(group_centers, group_sizes):
            connections.extend(_group_block(gc, g_size, N_in))
        return _dedupe_and_pad_from_window(connections, D_total, window_start, window_size, N_in)

    if synaptic_dist in ("spatial_gaussian", "spatial_inversegaussian"):
        # Circular distance of each group centre from the node's centre.
        distances = []
        for gc in group_centers:
            d_abs = abs(gc - base_center)
            distances.append(min(d_abs, N_in - d_abs))
        distances = np.array(distances)

        if synaptic_dist == "spatial_gaussian":
            # Nearby groups get more connections; groups grow right-first.
            sigma = getattr(args, "synaptic_std", 0.1 * N_in)
            gaussian_vals = np.exp(-0.5 * (distances / sigma) ** 2)
            first_sign = 1
        else:
            # Farther groups get more connections; groups grow left-first.
            # NOTE(review): reads a different arg name ("gaussian_sigma") than
            # the spatial_gaussian branch ("synaptic_std") — confirm intended.
            sigma = getattr(args, "gaussian_sigma", 0.1 * N_in)
            inverse_distances = np.max(distances) - distances
            gaussian_vals = np.exp(-0.5 * (inverse_distances / sigma) ** 2)
            first_sign = -1

        probs = gaussian_vals / np.sum(gaussian_vals)
        group_assignments = np.random.choice(M, size=D_total, p=probs)
        group_sizes = np.bincount(group_assignments, minlength=M)

        connections = []
        for gc, size in zip(group_centers, group_sizes):
            if size:
                connections.extend(_expand_alternating(gc, size, N_in, first_sign))
        return _dedupe_and_pad_closest(connections, D_total, int(base_center), N_in)

    raise ValueError(f"Unknown synaptic distribution: {synaptic_dist}")


def _adjust_samples(samples, target_total):
    """Convert float samples to integers that are each >= 1 and sum to ``target_total``.

    Steps:
    1. Round the samples and raise any value below 1 to 1.
    2. Close the gap using the rounding residuals. A shortfall adds 1 to the
       entries with the largest residuals. An excess removes 1 from reducible
       entries (> 1), smallest residuals first.
    3. Fix any remaining gap round-robin from index 0: add to entries, or
       repeatedly remove from entries > 1 until the target is met or every
       entry is 1.

    The target is met exactly whenever ``target_total >= len(samples)``.

    Returns:
        list of int, same length as ``samples`` (``[]`` for empty input).
    """
    samples = np.asarray(samples, dtype=float)
    if samples.size == 0:
        return []

    rounded = np.round(samples).astype(int)
    clipped = np.clip(rounded, 1, None)
    diff = int(target_total - np.sum(clipped))

    if diff != 0:
        fractional = samples - rounded
        if diff > 0:
            indices = np.argsort(-fractional)[:diff]
            clipped[indices] += 1
        else:
            reducible_indices = np.where(clipped > 1)[0]
            if len(reducible_indices) > 0:
                sorted_indices = reducible_indices[np.argsort(fractional[reducible_indices])]
                for i in range(min(-diff, len(sorted_indices))):
                    clipped[sorted_indices[i]] -= 1

    adj_diff = int(target_total - np.sum(clipped))
    if adj_diff > 0:
        for i in range(adj_diff):
            clipped[i % len(clipped)] += 1
    elif adj_diff < 0:
        # Repeat passes: a single pass removes at most one unit per entry,
        # which is not enough when a few entries carry most of the excess.
        to_remove = -adj_diff
        while to_remove > 0:
            progressed = False
            for i in range(len(clipped)):
                if to_remove == 0:
                    break
                if clipped[i] > 1:
                    clipped[i] -= 1
                    to_remove -= 1
                    progressed = True
            if not progressed:
                break  # every entry is already 1

    return clipped.tolist()


def create_dendritic_sparse_scheduler(sparsity, w, args):
    """Build a dendritic-style 0/1 connectivity mask for weight tensor *w*.

    The mask is built on an ``(N_in, N_out)`` grid with
    ``N_in = min(w.shape)`` and ``N_out = max(w.shape)``:

    1. Per-output-node degrees are drawn from ``args.degree_dist`` ("fixed",
       "gaussian", "uniform", "spatial_gaussian", "spatial_inversegaussian").
       They are then forced into ``[1, N_in]`` and to sum to
       ``round((1 - sparsity) * N_in * N_out)``.
    2. Each output node's inputs are picked by
       ``pick_connections_for_output_node``. It uses a per-node gamma
       (``args.gamma`` / ``args.gamma_dist``, defaults 0.5 / "fixed") and a
       per-node group count M (``args.M`` / ``args.M_dist``, default "fixed").
    3. If ``args.random_rewiring`` is non-zero, each edge is removed with that
       probability and the same number of edges is regrown at uniformly
       random empty positions.

    Returns:
        torch.LongTensor mask on ``w.device``, transposed when needed so its
        shape matches ``w``.

    Raises:
        ValueError: unknown degree/gamma/M distribution, or an infeasible
            total degree.
        RuntimeError: if the degrees cannot be made to sum to the target.
    """
    N_in = min(w.shape[0], w.shape[1])
    N_out = max(w.shape[0], w.shape[1])
    base_M = args.M

    total_target = int(round((1 - sparsity) * N_in * N_out))
    degree_dist = getattr(args, "degree_dist", "fixed")
    D_float = N_in * (1 - sparsity)  # mean in-degree per output node

    # --- 1. per-output-node connection counts ---------------------------
    if degree_dist == "fixed":
        # Mix floor(D) and ceil(D) degrees so they sum to total_target.
        K1 = int(D_float)
        K2 = K1 + 1
        count_K1 = int(total_target - K2 * N_out) // (K1 - K2)
        count_K2 = N_out - count_K1
        connection_counts = [K1] * count_K1 + [K2] * count_K2
    elif degree_dist == "gaussian":
        connection_std = getattr(args, "degree_std", 2 * D_float)
        samples = np.random.normal(D_float, connection_std, N_out)
        samples = np.clip(samples, 1, N_in)
        connection_counts = _adjust_samples(samples, total_target)
    elif degree_dist == "uniform":
        spread = getattr(args, "degree_spread", 4 * D_float)
        samples = np.random.uniform(D_float - spread, D_float + spread, N_out)
        connection_counts = _adjust_samples(samples, total_target)
    elif degree_dist in ("spatial_gaussian", "spatial_inversegaussian"):
        # Degree depends on the output node's distance from the layer centre.
        center = (N_out - 1) / 2.0
        sigma = getattr(args, "degree_std", D_float)
        distances = np.abs(np.arange(N_out) - center)
        bump = np.exp(-0.5 * (distances / sigma) ** 2)
        if degree_dist == "spatial_gaussian":
            weights = bump
        else:
            # 1.0 is the bump's peak value exp(0).
            weights = np.clip(1.0 - bump, 0, None)
        weight_sum = np.sum(weights)
        if weight_sum > 0:
            samples = weights * (total_target / weight_sum)
        else:
            samples = np.ones(N_out) * (total_target / N_out)
        connection_counts = _adjust_samples(samples, total_target)
    else:
        raise ValueError(f"Unknown degree distribution: {degree_dist}")

    if degree_dist not in ("spatial_gaussian", "spatial_inversegaussian"):
        # Spatial profiles are tied to node position; others are shuffled.
        random.shuffle(connection_counts)

    # --- 2. force counts into [1, N_in] and onto the exact total ---------
    connection_counts = np.clip(np.array(connection_counts, dtype=int), 1, N_in)
    diff = total_target - int(np.sum(connection_counts))

    if diff != 0:
        adjustable_inc = (connection_counts < N_in).sum()
        adjustable_dec = (connection_counts > 1).sum()
        max_inc = adjustable_inc * (N_in - 1)
        max_dec = adjustable_dec * (N_in - 1)

        if diff > 0 and diff > max_inc:
            raise ValueError(f"Cannot add {diff} connections (max possible: {max_inc})")
        if diff < 0 and -diff > max_dec:
            raise ValueError(f"Cannot remove {-diff} connections (max possible: {max_dec})")

        # Uniform choice among the nodes that can still move.
        probabilities = np.ones(N_out) / N_out

        while diff != 0:
            if diff > 0:
                candidates = np.where(connection_counts < N_in)[0]
                if len(candidates) == 0:
                    break
                idx = np.random.choice(candidates, p=probabilities[candidates] / probabilities[candidates].sum())
                connection_counts[idx] += 1
                diff -= 1
            else:
                candidates = np.where(connection_counts > 1)[0]
                if len(candidates) == 0:
                    break
                idx = np.random.choice(candidates, p=probabilities[candidates] / probabilities[candidates].sum())
                connection_counts[idx] -= 1
                diff += 1

    final_total = np.sum(connection_counts)
    if final_total != total_target:
        raise RuntimeError(f"Connection count mismatch: {final_total} vs {total_target} (Δ={final_total-total_target})")
    connection_counts = connection_counts.tolist()

    # --- 3. per-node gamma / M and connection placement ------------------
    adj = np.zeros((N_in, N_out), dtype=int)

    gamma = args.gamma if args.gamma is not None else 0.5
    gamma_dist = args.gamma_dist if args.gamma_dist is not None else "fixed"
    M_dist = getattr(args, "M_dist", None) or "fixed"

    center_out = (N_out - 1) / 2.0
    # Guard the normalisation used by spatial profiles when N_out == 1.
    max_distance = center_out if center_out > 0 else 1.0

    M_vals = None
    if M_dist in ("spatial_gaussian", "spatial_inversegaussian"):
        sigma = getattr(args, "M_std", N_out / 7.0)
        distances = np.abs(np.arange(N_out) - center_out)
        if M_dist == "spatial_gaussian":
            weights = np.exp(-0.5 * (distances / sigma) ** 2)
        else:
            max_weight = 1.0
            min_weight = 0.1
            weights = max_weight - (max_weight - min_weight) * np.exp(-0.5 * (distances / sigma) ** 2)
            weights = np.clip(weights, min_weight, max_weight)

        sum_weights = np.sum(weights)
        if sum_weights == 0:
            weights = np.ones(N_out)
            sum_weights = float(N_out)
        # Scale so the M values average base_M over the layer.
        M_vals_float = weights * ((N_out * base_M) / sum_weights)
        M_vals = np.floor(M_vals_float).astype(int)
        fractional = M_vals_float - M_vals

        remaining = int(round(N_out * base_M - np.sum(M_vals)))
        if remaining > 0:
            probs = fractional / np.sum(fractional)
            extra_indices = np.random.choice(N_out, size=remaining, p=probs, replace=True)
            np.add.at(M_vals, extra_indices, 1)

        M_vals = np.clip(M_vals, 1, None)

    for j in range(N_out):
        if gamma_dist == "fixed":
            gamma_j = gamma
        elif gamma_dist == "gaussian":
            gamma_std = getattr(args, "gamma_std", gamma * 0.1)
            gamma_j = np.clip(np.random.normal(gamma, gamma_std), 0, 1)
        elif gamma_dist == "uniform":
            spread = 0.25
            # Clip like the other distributions so gamma stays a valid mix weight.
            gamma_j = np.clip(np.random.uniform(gamma - spread, gamma + spread), 0.0, 1.0)
        elif gamma_dist == "spatial_gaussian":
            normalized_dist = abs(j - center_out) / max_distance
            mean_gamma_j = 1.0 - normalized_dist
            gamma_std = getattr(args, "gamma_std", gamma * 0.05)
            gamma_j = np.clip(np.random.normal(mean_gamma_j, gamma_std), 0.0, 1.0)
        elif gamma_dist == "spatial_inversegaussian":
            mean_gamma_j = gamma * (abs(j - center_out) / max_distance)
            gamma_std = getattr(args, "gamma_std", gamma * 0.05)
            gamma_j = np.clip(np.random.normal(mean_gamma_j, gamma_std), 0.0, 1.0)
        else:
            raise ValueError("Unknown gamma distribution: {}".format(gamma_dist))

        D_total = connection_counts[j]
        if M_dist == "fixed":
            M_j = base_M
        elif M_dist == "gaussian":
            M_std = getattr(args, "M_std", base_M / 4)
            M_j = np.random.normal(base_M, M_std / 2)
            M_j = int(np.round(np.clip(M_j, 1, D_total)))
        elif M_dist == "uniform":
            spread = getattr(args, "M_spread", base_M)
            M_j = np.random.uniform(base_M - spread, base_M + spread)
            M_j = int(np.round(np.clip(M_j, 1, D_total)))
        elif M_dist in ("spatial_gaussian", "spatial_inversegaussian"):
            M_j = max(1, min(M_vals[j], D_total))
        else:
            raise ValueError(f"Unknown M distribution: {M_dist}")

        connections = pick_connections_for_output_node(
            j, N_in, N_out, sparsity, D_total, M_j, gamma_j, args
        )
        for i in connections:
            adj[i, j] = 1

    degrees = adj.sum(axis=0)
    print("Degree stats:", np.min(degrees), np.max(degrees), np.mean(degrees))

    # --- 4. random rewiring (edge count preserved) -----------------------
    if args.random_rewiring != 0:
        # np.nonzero lists edges in row-major order, one Bernoulli draw per edge.
        rows, cols = np.nonzero(adj)
        total_edges = rows.size
        randomness = np.random.binomial(1, p=args.random_rewiring, size=total_edges)
        dropped = randomness == 1
        adj[rows[dropped], cols[dropped]] = 0

        removed_edges = total_edges - int(np.sum(adj))
        nrAdd = 0
        while nrAdd < removed_edges:
            i_rand = np.random.randint(0, N_in)
            j_rand = np.random.randint(0, N_out)
            if adj[i_rand, j_rand] == 0:
                adj[i_rand, j_rand] = 1
                nrAdd += 1
        print("After rewiring, total edges:", np.sum(adj), "removed:", removed_edges)

    if w.shape[0] != N_in:
        return torch.LongTensor(adj).to(w.device).t()
    return torch.LongTensor(adj).to(w.device)




def create_ws_sparse_scheduler(sparsity, w, args):
    """Build a Watts-Strogatz-style 0/1 mask (ring lattice plus rewiring) for *w*.

    With ``indim = min(w.shape)`` and ``outdim = max(w.shape)``, column ``i``
    connects to a contiguous, wrapping run of rows centred near
    ``i * indim / outdim``. Column degrees are floor(K) or ceil(K) with
    ``K = (1 - sparsity) * indim``, mixed so they total about
    ``K * outdim``.

    When ``args.ws_beta`` is non-zero, each edge is dropped with that
    probability. Edges are then regrown at uniformly random empty positions
    until the mask holds ``int((1 - sparsity) * indim * outdim)`` ones.

    Returns:
        torch.LongTensor on ``w.device``, transposed when needed so its
        shape matches ``w``.
    """
    indim = min(w.shape[0], w.shape[1])
    outdim = max(w.shape[0], w.shape[1])

    K = (1 - sparsity) * indim
    K1 = int(K)
    K2 = K1 + 1
    # Exactly outdim degrees: n_K2 columns get ceil(K), the rest floor(K).
    n_K2 = int(round(outdim * (K - K1)))
    my_list = [K1] * (outdim - n_K2) + [K2] * n_K2
    random.shuffle(my_list)

    adj = np.zeros((indim, outdim))
    rate = indim / outdim
    for i in range(outdim):
        idx = [(int(i * rate - my_list[i] / 2) + j) % indim for j in range(my_list[i])]
        adj[idx, i] = 1

    # rewiring
    if args.ws_beta != 0:
        # np.nonzero lists edges in row-major order, one Bernoulli draw per edge.
        rows, cols = np.nonzero(adj)
        randomness = np.random.binomial(1, p=args.ws_beta, size=rows.size)
        dropped = randomness == 1
        adj[rows[dropped], cols[dropped]] = 0

        # regrow at uniformly random empty positions up to the target density
        noRewires = int((1 - sparsity) * indim * outdim) - np.sum(adj)
        nrAdd = 0
        while nrAdd < noRewires:
            i = np.random.randint(0, indim)
            j = np.random.randint(0, outdim)
            if adj[i][j] == 0:
                nrAdd += 1
                adj[i][j] = 1

        print(np.sum(adj), noRewires)

    if w.shape[0] != indim:
        return torch.LongTensor(adj).to(w.device).t()
    return torch.LongTensor(adj).to(w.device)


def create_cws_sparse_scheduler(sparsity, w, args):
    """Build a ring-lattice mask, randomly drop edges, then regrow from a 2-D Gaussian.

    1. Ring lattice. With ``indim = min(w.shape)`` and
       ``outdim = max(w.shape)``, each column gets floor(k) or ceil(k)
       consecutive, wrapping rows (``k = (1 - sparsity) * indim``), centred
       near ``col * indim / outdim``.
    2. Each edge is dropped with probability ``args.random_rewiring``
       (default 0).
    3. Missing edges, up to ``int((1 - sparsity) * indim * outdim)``, are
       regrown at cells sampled from a bivariate normal centred on the grid.
       ``args.sigma_x`` / ``args.sigma_y`` are in cells and ``args.rho`` is
       the correlation. The spread widens when progress stalls, and the last
       few edges are placed uniformly at random. If the lattice already meets
       or exceeds the target, nothing is regrown.

    Returns:
        torch long tensor on ``w.device``, transposed when needed so its shape
        matches ``w``.
    """
    indim, outdim = min(w.shape), max(w.shape)
    device = w.device

    # 1. ring-lattice ---------------------------------------------------
    k = (1 - sparsity) * indim
    k_low = int(np.floor(k))
    k_high = k_low + 1
    n_high = int(round(outdim * (k - k_low)))
    degrees = np.array([k_high] * n_high + [k_low] * (outdim - n_high))
    np.random.shuffle(degrees)

    adj = np.zeros((indim, outdim), np.uint8)
    ratio = indim / outdim
    for col, d in enumerate(degrees):
        start = int(round(col * ratio - d / 2))
        adj[(start + np.arange(d)) % indim, col] = 1

    # 2. drop edges -----------------------------------------------------
    p = getattr(args, "random_rewiring", 0.0)
    if p > 0:
        e = np.argwhere(adj)  # (N, 2) edge coordinates
        keep = np.random.rand(len(e)) > p
        adj[:] = 0
        adj[e[keep, 0], e[keep, 1]] = 1

    # 3. global Gaussian regrowth --------------------------------------
    target = int((1 - sparsity) * indim * outdim)
    # Can be negative: rounding n_high may put the lattice above target.
    missing = target - int(adj.sum())
    if missing > 0:
        sx, sy = args.sigma_x, args.sigma_y  # already in cells
        rho = args.rho
        cov = np.array([[sx ** 2, rho * sx * sy],
                        [rho * sx * sy, sy ** 2]])
        mean = np.array([indim / 2, outdim / 2])

        batch = 2048
        stuck = 0
        while missing > 0:
            pts = np.random.multivariate_normal(mean, cov, batch)
            ii = np.mod(np.rint(pts[:, 0]).astype(int), indim)
            jj = np.mod(np.rint(pts[:, 1]).astype(int), outdim)

            flat = np.unique(ii * outdim + jj)     # drop duplicates (sorted)
            free = flat[adj.ravel()[flat] == 0]    # keep only empty cells
            # np.unique sorts, so shuffle before truncating to avoid always
            # taking the lowest-index (top-row) cells of the sampled cloud.
            take = np.random.permutation(free)[:missing]
            adj.ravel()[take] = 1
            gained = take.size
            missing -= gained

            # widen the Gaussian / enlarge the batch if progress is poor
            if gained < 0.01 * batch:
                stuck += 1
            else:
                stuck = 0
            if stuck >= 5:
                sx *= 1.5
                sy *= 1.5
                cov[0, 0], cov[1, 1] = sx ** 2, sy ** 2
                cov[0, 1] = cov[1, 0] = rho * sx * sy
                stuck, batch = 0, min(batch * 2, 65536)

            # fall back to uniform placement for the very last few percent
            if 0 < missing < 0.05 * target * p:
                zeros = np.where(adj.ravel() == 0)[0]
                picks = np.random.choice(zeros, missing, replace=False)
                adj.ravel()[picks] = 1
                missing = 0

    mask = torch.from_numpy(adj).long().to(device)
    return mask.t() if w.shape[0] != indim else mask

def create_ws1_sparse_scheduler(sparsity, w, args):
    """Build a ring-band (Watts-Strogatz-style) mask with in-column rewiring.

    The adjacency is ``(indim, outdim)`` where ``indim``/``outdim`` are the
    smaller/larger dimension of ``w``.  Column ``i`` starts with a contiguous,
    wrap-around band of ``floor(K)`` or ``floor(K) + 1`` ones
    (``K = (1 - sparsity) * indim``) centred on row ``i * indim / outdim``.
    When ``args.random_rewiring > 0`` a fraction ``random_rewiring`` of each
    column's links is cut and regrown in the same column, so every column
    keeps its degree:

    * ``args.delta == 1.0``: regrow at the free row with the smallest ring
      distance to row ``j % indim`` (first such row on ties);
    * ``args.delta < 1.0``: sample a free row with probability proportional
      to ``(1 + d) ** (-alpha)``, ``alpha = delta / (1 - delta)``.

    Args:
        sparsity: target fraction of zeros.
        w: weight tensor; only its shape and device are read.
        args: namespace with ``random_rewiring`` and, when rewiring, ``delta``.

    Returns:
        A ``torch.long`` 0/1 mask with the same shape as ``w``, on ``w.device``.
    """
    device = w.device
    indim = min(w.shape[0], w.shape[1])
    outdim = max(w.shape[0], w.shape[1])

    # Split the fractional target degree K between floor(K) and floor(K)+1 so
    # the mean column degree is ~K (list holds outdim or outdim+1 entries).
    K = (1 - sparsity) * indim
    K1 = int(K)
    K2 = int(K) + 1
    my_list = [K1] * int(outdim * (K2 - K)) + [K2] * int(outdim * (K - K1) + 1)
    random.shuffle(my_list)

    adj = np.zeros((indim, outdim))

    rate = indim / outdim
    # Second shuffle is redundant but kept so seeded runs stay reproducible.
    random.shuffle(my_list)
    for i in range(outdim):
        # contiguous wrap-around band of my_list[i] rows centred on i * rate
        idx = [(int(i * rate - my_list[i] / 2) + j) % indim for j in range(my_list[i])]
        adj[idx, i] = 1

    # rewiring
    if args.random_rewiring > 0:
        target_links = int(adj.sum())
        alpha = None
        if args.delta < 1.0:
            alpha = args.delta / (1.0 - args.delta)

        for j in range(outdim):
            col_links = np.flatnonzero(adj[:, j])
            n_rewire = int(len(col_links) * args.random_rewiring)
            if n_rewire == 0:
                continue

            # 1) disconnect a random subset of this column's links
            victims = np.random.choice(col_links, size=n_rewire, replace=False)
            adj[victims, j] = 0

            # 2) regrow the same number of links in this column
            for _ in range(n_rewire):
                free = np.flatnonzero(adj[:, j] == 0)
                if free.size == 0:
                    break  # cannot happen: every cut freed a row

                dist = np.abs(free - (j % indim))
                dist = np.minimum(dist, indim - dist)  # ring distance
                if args.delta == 1.0:
                    new_i = free[dist == dist.min()][0]
                else:
                    # Named ``weights`` (not ``w``) so the ``w`` parameter,
                    # which decides the output orientation below, is not
                    # shadowed by a numpy vector.
                    weights = (1.0 + dist) ** (-alpha)
                    new_i = np.random.choice(free, p=weights / weights.sum())

                adj[new_i, j] = 1

        assert adj.sum() == target_links, \
               "Degree changed during rewiring – check logic"

    mask = torch.from_numpy(adj).long().to(device)
    return mask.t() if w.shape[0] != indim else mask


def create_ws2_sparse_scheduler(sparsity, w, args):
    """Build a ring-band mask whose rewired links may only land off the band.

    The adjacency is ``(indim, outdim)`` where ``indim``/``outdim`` are the
    smaller/larger dimension of ``w``.  Column ``i`` starts with a contiguous,
    wrap-around band of ``floor(K)`` or ``floor(K) + 1`` rows
    (``K = (1 - sparsity) * indim``) centred on row ``i * indim / outdim``.
    With ``args.random_rewiring > 0`` a fraction of each column's links is
    cut and regrown in the same column, but only at rows that were *not* on
    the initial band:

    * ``args.delta == 1.0``: the free off-band row with the smallest ring
      distance to row ``j % indim``;
    * ``args.delta < 1.0``: a free off-band row sampled with probability
      proportional to ``(1 + d) ** (-delta / (1 - delta))``.

    The number of links moved per column is capped at the number of off-band
    rows, so every column keeps its degree even when the band is dense.

    Args:
        sparsity: target fraction of zeros.
        w: weight tensor; only its shape and device are read.
        args: namespace with ``random_rewiring`` and, when rewiring, ``delta``.

    Returns:
        A ``torch.long`` 0/1 mask with the same shape as ``w``, on ``w.device``.
    """
    device = w.device
    indim  = min(w.shape[0], w.shape[1])
    outdim = max(w.shape[0], w.shape[1])

    # ---------- 1. build the deterministic ring-band sparsity  ----------
    K        = (1 - sparsity) * indim
    K1, K2   = int(K), int(K) + 1
    my_list  = [K1] * int(outdim * (K2 - K)) + [K2] * int(outdim * (K - K1) + 1)
    random.shuffle(my_list)

    adj      = np.zeros((indim, outdim), dtype=np.int8)
    rate     = indim / outdim
    for i in range(outdim):
        idx = [(int(i * rate - my_list[i] / 2) + j) % indim for j in range(my_list[i])]
        adj[idx, i] = 1

    # ---------- 2. remember the initial band so regrowth can avoid it ----------
    band_mask = adj.astype(bool)        # True == position was on the original band

    # ---------- 3. random rewiring, but only to off-band rows  ----------
    if args.random_rewiring > 0:
        target_links = int(adj.sum())
        alpha = None
        if args.delta < 1.0:
            alpha = args.delta / (1.0 - args.delta)   # α ≥ 0

        for j in range(outdim):
            col_links = np.flatnonzero(adj[:, j])
            # Column j is untouched until now, so its off-band rows are all
            # free; cutting more links than that could never be regrown and
            # would break the degree check below.
            n_offband = int(np.count_nonzero(~band_mask[:, j]))
            n_rewire  = min(int(len(col_links) * args.random_rewiring), n_offband)
            if n_rewire == 0:
                continue

            # 1) cut a random subset of existing (band) links
            victims = np.random.choice(col_links, size=n_rewire, replace=False)
            adj[victims, j] = 0

            # 2) add the same number of off-band links
            for _ in range(n_rewire):
                free = np.flatnonzero((adj[:, j] == 0) & (~band_mask[:, j]))
                if free.size == 0:
                    break  # unreachable thanks to the cap above

                d = np.abs(free - (j % indim))
                d = np.minimum(d, indim - d)          # ring distance
                if args.delta == 1.0:
                    new_i = free[d == d.min()][0]     # first closest row
                else:
                    w_prob = (1.0 + d) ** (-alpha)
                    new_i  = np.random.choice(free, p=w_prob / w_prob.sum())

                adj[new_i, j] = 1

        assert adj.sum() == target_links, \
               "Degree changed during rewiring – check logic"

    mask = torch.from_numpy(adj).long().to(device)
    return mask.t() if w.shape[0] != indim else mask

def create_ws3_sparse_scheduler(sparsity, w, args):
    """Sample a distance-biased sparse mask row by row.

    The adjacency is ``(in_dim, out_dim)`` where ``in_dim``/``out_dim`` are the
    smaller/larger dimension of ``w``.  Row ``i`` receives ``row_degrees[i]``
    distinct columns around ``center = (i * (out_dim // in_dim)) % out_dim``
    (ring distance):

    * ``args.delta == 0``: the ``Ki`` closest columns;
    * otherwise: sampled without replacement with weight
      ``(1 + d) ** (-(1 - delta) / delta)`` (``delta == 1`` is uniform).

    Row degrees come from ``K = (1 - sparsity) * out_dim``: either a shuffled
    mix of ``floor(K)``/``floor(K)+1`` (default) or, with
    ``args.degree_dist == "uniform"``, one link per row plus the remainder
    scattered uniformly over rows.  Degrees are never allowed to exceed
    ``out_dim`` (a row cannot hold more links than there are columns).

    Raises:
        ValueError: if ``sparsity`` is outside ``[0, 1)``, ``delta`` outside
            ``[0, 1]``, or the uniform scheme has fewer links than rows.

    Returns:
        A ``torch.long`` 0/1 mask with the same shape as ``w``, on ``w.device``.
    """
    delta = args.delta
    if not (0.0 <= sparsity < 1.0):
        raise ValueError("'sparsity' must be in [0, 1).")
    if not (0.0 <= delta <= 1.0):
        raise ValueError("'delta' must be in [0, 1].")

    device = w.device
    in_dim  = min(w.shape[0], w.shape[1])
    out_dim = max(w.shape[0], w.shape[1])

    K_float = (1.0 - sparsity) * out_dim
    total_links = int(round(K_float * in_dim))

    if getattr(args, "degree_dist", None) == "uniform":
        if total_links < in_dim:
            raise ValueError("Total links < number of rows: cannot give each row degree 1")
        row_degrees = [1] * in_dim
        for _ in range(total_links - in_dim):
            # Rejection-sample a row that still has room.  Since
            # total_links <= in_dim * out_dim, such a row always exists, and
            # when no row saturates the draw sequence matches plain randrange.
            i = random.randrange(in_dim)
            while row_degrees[i] >= out_dim:
                i = random.randrange(in_dim)
            row_degrees[i] += 1
    else:
        # rounded distribution: rows get K1 or K2 links, mean ~K_float
        K1, K2 = int(K_float), int(K_float) + 1
        row_degrees = [K1] * int(in_dim * (K2 - K_float)) + \
                      [K2] * int(in_dim * (K_float - K1) + 1)
        random.shuffle(row_degrees)
        # At sparsity 0, K2 == out_dim + 1 would exceed the number of columns.
        row_degrees = [min(k, out_dim) for k in row_degrees]

    # exponent of the distance weight; unused for delta == 0 (deterministic)
    alpha = None if delta == 0.0 else (1.0 - delta) / delta

    adj = np.zeros((in_dim, out_dim), dtype=np.int8)
    columns = np.arange(out_dim)
    stride = out_dim // in_dim
    for i in range(in_dim):
        Ki = row_degrees[i]  # number of links in this row
        if Ki <= 0:
            continue

        center = (i * stride) % out_dim
        distances = np.abs(columns - center)
        distances = np.minimum(distances, out_dim - distances)  # ring distance

        if delta == 0.0:
            chosen = np.argsort(distances)[:Ki]
        else:
            weights = (1.0 + distances) ** (-alpha)
            weights = weights / weights.sum()
            chosen = np.random.choice(out_dim, size=Ki, replace=False, p=weights)

        adj[i, chosen] = 1

    mask = torch.from_numpy(adj).long().to(device)
    # transpose when w is laid out as (out_dim, in_dim)
    return mask.t() if w.shape[1] != out_dim else mask





def create_ws_cross_scheduler(sparsity, w, args):
    """
    Build a 'cross'-shaped Watts-Strogatz-style mask (main + anti-diagonal bands).

    Works on an (indim, outdim) adjacency, where indim/outdim are the smaller/
    larger dimension of ``w``:
      – each column gets a target degree of floor(K) or floor(K)+1, with
        K = (1 - sparsity) * indim, assigned to columns in shuffled order;
      – ceil(k/2) of a column's k links go on a band around the main diagonal
        and floor(k/2) on a band around the anti-diagonal (both wrap modulo
        indim). Where the two bands overlap the entries coincide, so a column
        can end up with fewer than k ones;
      – if ``args.random_rewiring`` is truthy, every existing link is dropped
        independently with probability ``random_rewiring``, then zero cells
        chosen uniformly over the whole matrix are switched on until the total
        reaches int((1 - sparsity) * indim * outdim). Nothing is added if the
        count is already at or above that target.

    Args:
        sparsity: target fraction of zeros.
        w: weight tensor; only its shape and device are read.
        args: namespace providing ``random_rewiring``.

    Returns:
        A torch.long 0/1 mask on ``w.device``; transposed to (outdim, indim)
        when ``w.shape[0] != indim``.
    """
    device = w.device
    indim  = min(w.shape[0], w.shape[1])
    outdim = max(w.shape[0], w.shape[1])

    # ---------- degree distribution ----------------------------------------------------------
    # first int(outdim*(K2-K)) columns get K1 links, the rest K2; then shuffled
    K_float = (1.0 - sparsity) * indim
    K1, K2  = int(K_float), int(K_float) + 1
    my_list = np.empty(outdim, dtype=np.int16)
    my_list[: int(outdim * (K2 - K_float))]         = K1
    my_list[int(outdim * (K2 - K_float)) :]         = K2
    np.random.shuffle(my_list)

    # ---------- make the two diagonals -------------------------------------------------------
    # offsets are [-max_k//2, …, max_k - 1 - max_k//2]; each band takes a
    # prefix of this window, so a short band sits on the lower side of its centre
    max_k = my_list.max()          # largest column degree = width of the offset window
    offsets = np.arange(max_k) - max_k // 2         # shape (max_k,)

    # per-column centre of each diagonal
    cols = np.arange(outdim)
    rate = indim / outdim
    c_main  = (cols * rate).astype(np.int64)           # ↘︎
    c_cross = (indim - 1 - cols * rate).astype(np.int64)  # ↙︎

    # build row indices with broadcasting, wrapping modulo indim
    idx_main  = (c_main[:, None]  + offsets[None, :]) % indim  # (outdim, max_k)
    idx_cross = (c_cross[:, None] + offsets[None, :]) % indim

    # create dense mask on the CPU first
    adj = np.zeros((indim, outdim), dtype=np.int8)

    # we cannot write whole (outdim, max_k) blocks because each column has its own k
    for col, k in enumerate(my_list):
        if k == 0:
            continue
        k_main  = (k + 1) // 2     # ceil(k/2) on the main diagonal
        k_cross =  k       // 2    # floor(k/2) on the anti-diagonal
        adj[idx_main[col, :k_main],   col] = 1
        adj[idx_cross[col, :k_cross], col] = 1

    # actual link count (may be < sum(my_list) where the bands overlap)
    nnz = adj.sum()

    # ---------- rewiring ---------------------------------------------------------------------
    if args.random_rewiring:
        # 1. drop each existing edge independently with probability random_rewiring
        ones_r, ones_c = np.nonzero(adj)
        keep = np.random.rand(nnz) >= args.random_rewiring
        adj[ones_r[~keep], ones_c[~keep]] = 0
        nnz = keep.sum()

        # 2. regrow uniformly among all zero cells (dropped ones included) up to target
        target = int((1.0 - sparsity) * indim * outdim)
        missing = target - nnz
        if missing > 0:
            zeros_r, zeros_c = np.nonzero(adj == 0)          # all zero cells
            pick = np.random.choice(len(zeros_r), missing, replace=False)
            adj[zeros_r[pick], zeros_c[pick]] = 1
            nnz = target                                     # now full

    # ---------- tensor to the right shape and device -------------------------
    adj_t = torch.from_numpy(adj).to(device, dtype=torch.long)
    return adj_t.t() if w.shape[0] != indim else adj_t



def spatial_sort_order(N):
    """Return the positions ``0..N-1`` in "centre → halves → quarters → …" order.

    At refinement level ``k = 1, 2, 4, …`` the candidates are
    ``round((2m + 1) / (2k) * (N - 1))`` for ``m < k`` (Python ``round``,
    ties to even); each position is emitted the first time it appears, until
    all ``N`` are listed.  Membership is tracked with a set, so the whole
    order is built in linear time rather than quadratic.

    >>> spatial_sort_order(8)
    [4, 2, 5, 1, 3, 6, 0, 7]
    """
    order = []
    seen = set()
    span = N - 1
    k = 1
    while len(order) < N:
        for m in range(k):
            pos = round((2 * m + 1) / (2 * k) * span)
            if pos in seen:
                continue
            seen.add(pos)
            order.append(pos)
            if len(order) == N:
                break
        k *= 2
    return order

def create_BHI_sparse_scheduler(sparsity, w, args, is_last_layer=False):
    """Build a bipartite 0/1 mask for ``w`` from an nPSO bipartite graph.

    Calls ``nPSO_bipartite`` (defined elsewhere) with the two layer sizes,
    ``sparsity``, ``args.BHI_T``, ``args.BHI_gamma``, ``args.BHI_distr`` and
    ``args.rewire_mode``, and uses the first value it returns as the adjacency
    (densified with ``toarray()`` if it has one, else via ``np.array``).

    If ``args.degree_allocation`` is truthy, columns (and, unless
    ``is_last_layer``, rows) are permuted so that the highest-degree nodes go
    first in a breadth-first midpoint order, i.e. near the middle of the
    axis, and lower-degree nodes progressively further out.

    Args:
        sparsity: target sparsity forwarded to ``nPSO_bipartite``.
        w: 2-D weight tensor; only its shape and device are read.
        args: namespace with ``BHI_T``, ``BHI_gamma``, ``BHI_distr``,
            ``rewire_mode`` and ``degree_allocation``.
        is_last_layer: when True, rows are not permuted.

    Returns:
        A torch.int64 0/1 mask on ``w.device``; transposed if the sizes were
        swapped before calling ``nPSO_bipartite``. This restores ``w``'s
        orientation only if ``nPSO_bipartite`` returns an (Ni, No) matrix —
        TODO confirm.
    """
    Ni, No = w.shape
    transposed = False
    # Ensure N_out >= N_in
    # NOTE(review): the sizes are also swapped when No >= Ni but No is not a
    # multiple of Ni, which leaves No < Ni; and Ni == 0 raises
    # ZeroDivisionError here. Confirm what nPSO_bipartite actually requires.
    if No < Ni or (No % Ni) != 0:
        Ni, No = No, Ni
        transposed = True

    adj, *_ = nPSO_bipartite(
        Ni, No,
        sparsity,
        args.BHI_T,
        args.BHI_gamma,
        args.BHI_distr,
        args.rewire_mode
    )

    # get dense matrix
    if hasattr(adj, "toarray"):
        mat = adj.toarray().astype(int)
    else:
        mat = np.array(adj, dtype=int)

    # Breadth-first midpoint order of range(N), e.g. N=8 -> [3, 1, 5, 0, 2, 4, 6, 7].
    # Shadows the module-level spatial_sort_order, which uses a different
    # (round-based) order.
    def spatial_sort_order(N):
        order = []
        queue = [(0, N)]
        while queue and len(order) < N:
            start, end = queue.pop(0)
            length = end - start
            if length <= 0:
                continue
            mid = start + (length - 1) // 2
            order.append(mid)
            queue.append((start, mid))
            queue.append((mid + 1, end))
        return order

    if args.degree_allocation:
        # 1) permute columns by input-degree (column sums): the idx-th
        #    highest-degree column moves to position pos_in[idx]
        deg_in = mat.sum(axis=0)
        sorted_in = np.argsort(-deg_in)
        pos_in = spatial_sort_order(No)
        col_perm = np.empty(No, dtype=int)
        # pos_in has No entries; the `else idx` fallback only matters if mat
        # has a different column count than No (presumably never — confirm)
        for idx, old in enumerate(sorted_in):
            target = pos_in[idx] if idx < len(pos_in) else idx
            col_perm[target] = old
        col_perm = col_perm.astype(int)
        mat = mat[:, col_perm]

        # 2) permute rows by output-degree (row sums) if not last layer
        if not is_last_layer:
            deg_out = mat.sum(axis=1)
            sorted_out = np.argsort(-deg_out)
            pos_out = spatial_sort_order(Ni)
            row_perm = np.empty(Ni, dtype=int)
            for idx, old in enumerate(sorted_out):
                target = pos_out[idx] if idx < len(pos_out) else idx
                row_perm[target] = old
            row_perm = row_perm.astype(int)
            mat = mat[row_perm, :]

    mask = torch.from_numpy(mat.astype(np.int64))
    if transposed:
        mask = mask.t()
    return mask.to(w.device)


def create_QHI_sparse_scheduler(sparsity, W, args):
    """Build masks for three stacked layers from one quadpartite nPSO graph.

    The four layer widths are read off the weights: ``A`` is the input width
    (``shape[1]``) of ``W[0]``; ``B``, ``C`` and ``D`` are the output widths
    (``shape[0]``) of ``W[0]``, ``W[1]`` and ``W[2]``.  ``args.BHI_distr == 0``
    selects the ``'random'`` node distribution, any other value ``'community'``.

    Returns:
        ``[mask0, mask1, mask2]`` built with ``torch.LongTensor`` from the
        B-A, C-B and D-C adjacencies, each moved to the device of the matching
        weight tensor.
    """
    first, second, third = W[0], W[1], W[2]
    size_a = first.shape[1]
    size_b = first.shape[0]
    size_c = second.shape[0]
    size_d = third.shape[0]

    if args.BHI_distr == 0:
        distr = 'random'
    else:
        distr = 'community'

    # The call also returns the node coordinates of every layer; only the
    # three adjacency blocks are needed here.
    x_dc, x_cb, x_ba, _, _, _, _ = nPSO_quadpartite(
        size_a, size_b, size_c, size_d,
        sparsity,
        args.BHI_T,
        args.BHI_gamma,
        distr
    )

    pairs = ((x_ba, first), (x_cb, second), (x_dc, third))
    return [torch.LongTensor(block).to(layer_w.device) for block, layer_w in pairs]




def generate_barabasi_alberta_graph(N, m):
    """Grow an undirected preferential-attachment (Barabási–Albert) graph.

    The graph is seeded with a clique on the first ``floor(m) + 1`` nodes
    (clipped to ``N``).  Every later node ``i`` then links to distinct
    earlier nodes chosen with probability proportional to their current
    degree (uniformly if all earlier degrees are zero).  For a non-integer
    ``m`` the number of links per new node is a shuffled mix of
    ``floor(m)`` and ``floor(m) + 1`` averaging about ``m``; for an integer
    ``m`` every new node adds exactly ``m`` links.

    Degrees are tracked incrementally instead of re-summing the adjacency
    each step, which gives the same probabilities at O(N^2) instead of O(N^3).

    Args:
        N: number of nodes.
        m: average number of links added per new node (int or float).

    Returns:
        An ``(N, N)`` symmetric float64 0/1 adjacency with a zero diagonal.
    """
    adj = np.zeros((N, N))
    m1 = int(m)
    m2 = m1 + 1

    seed = min(m2, N)
    clique = np.triu(np.ones((seed, seed)), k=1)
    adj[:seed, :seed] = clique + clique.T

    n_new = N - m2
    if isinstance(m, (int, np.integer)):
        links_per_node = [m1] * max(n_new, 0)
    else:
        links_per_node = [m1] * int(n_new * (m2 - m)) + [m2] * int(n_new * (m - m1) + 1)
        random.shuffle(links_per_node)

    # degree[j] for j < i equals the row sum of adj[:i, :i], because nodes
    # >= i have no edges yet.
    degree = adj.sum(axis=1)
    for i in range(m2, N):
        k = links_per_node[i - m2]
        existing = degree[:i]
        total = existing.sum()
        p = existing / total if total > 0 else None
        idx = np.random.choice(np.arange(i), size=k, replace=False, p=p)
        adj[i, idx] = 1
        adj[idx, i] = 1
        degree[idx] += 1
        degree[i] += k

    return adj

def create_ba_sparse(layer):
    """Set ``layer.weight_mask`` from a Barabási–Albert graph split into two layers.

    A preferential-attachment graph is grown on ``indim + outdim`` nodes with
    ``m = (1 - sparsity) * indim * outdim / (indim + outdim)``, so roughly
    ``m * (indim + outdim) ≈ (1 - sparsity) * indim * outdim`` edges exist.
    The nodes are randomly split into an input group (``indim`` nodes) and an
    output group (``outdim`` nodes); edges between the groups form the mask.

    Edges inside a group ("frustrated" edges) are converted into cross-group
    links. Their endpoints are first paired greedily across the groups. Then,
    for the first half of the endpoints left unpaired on the larger side, a
    link is attached to a node on the other side sampled with probability
    proportional to its current degree.

    Reads ``layer.sparsity``, ``layer.indim``, ``layer.outdim`` and
    ``layer.device``; writes ``layer.weight_mask`` as a float tensor of shape
    ``(indim, outdim)``. Prints the resulting number of links. Returns None.
    """
    m = (1- layer.sparsity) * layer.indim * layer.outdim / (layer.indim + layer.outdim)
    adj = generate_barabasi_alberta_graph(layer.indim + layer.outdim, m)
    # random split: first indim shuffled nodes -> input group, rest -> output group
    nodes = list(range(layer.indim + layer.outdim))
    random.shuffle(nodes)

    # NOTE(review): list(set(...)) discards the shuffled order; group membership
    # is still random, but nodes inside a group end up in set-iteration order.
    layer_N = list(set(nodes[:layer.indim]))
    layer_M = list(set(nodes[layer.indim:]))

    # reorder rows and columns so input-group nodes come first
    adj = adj[layer_N+layer_M]
    adj = adj[:, layer_N+layer_M]

    # keep each undirected edge once (upper triangle); every input-output edge
    # has row < indim <= column, so all of them survive in the upper-right block
    adj = np.triu(adj, k=1)
    # cross-group block (indim x outdim); a view into adj
    final_adj = adj[:layer.indim, layer.indim:]
    # Endpoints of intra-group edges: nonzero() yields the row indices then the
    # column indices, so each edge contributes both endpoints. Indices are local
    # to the group: row indices of final_adj for layer1_fru, column indices for
    # layer2_fru.
    layer1_fru = np.array(adj[:layer.indim, :layer.indim].nonzero()).reshape(-1)
    layer2_fru = np.array(adj[layer.indim:, layer.indim:].nonzero()).reshape(-1)
    np.random.shuffle(layer1_fru)
    np.random.shuffle(layer2_fru)

    if len(layer1_fru) <= len(layer2_fru):
        # Pair each input-side endpoint with the first unused output-side
        # endpoint it is not already linked to, and add that link.
        # print("in")
        layer2_fru_flag = np.zeros_like(layer2_fru)
        for i in range(len(layer1_fru)):
            for j in range(len(layer2_fru)):
                if layer2_fru_flag[j] == 0:
                    if final_adj[layer1_fru[i], layer2_fru[j]]:
                        continue
                    else:
                        final_adj[layer1_fru[i], layer2_fru[j]] = 1
                        layer2_fru_flag[j] = 1
                        break
                else:
                    continue

        # For the first half of the unpaired output-side endpoints, link them to
        # an input node sampled proportionally to its degree in final_adj,
        # resampling while the drawn node is already linked.
        zero_indices = np.where(layer2_fru_flag == 0)[0]
        layer2_r = layer2_fru[zero_indices]
        for i in range(len(layer2_r)//2):
            node_degrees = np.sum(final_adj, axis=1)
            prob_distribution = node_degrees / np.sum(node_degrees)
            # NOTE(review): never terminates if this column is already linked to
            # every input node with nonzero degree; probabilities are NaN if all
            # input degrees are zero.
            while True:
                target_node = np.random.choice(np.arange(layer.indim), size=1, replace=False, p=prob_distribution)
                if final_adj[target_node, layer2_r[i]] == 1:
                    continue
                else:
                    final_adj[target_node, layer2_r[i]] = 1
                    break

    elif len(layer1_fru) > len(layer2_fru):
        # Mirror image of the branch above: pair each output-side endpoint with
        # the first unused input-side endpoint it is not already linked to.
        layer1_fru_flag = np.zeros_like(layer1_fru)
        for i in range(len(layer2_fru)):
            for j in range(len(layer1_fru)):
                if layer1_fru_flag[j] == 0:
                    if final_adj[layer1_fru[j], layer2_fru[i]]:
                        continue
                    else:
                        final_adj[layer1_fru[j], layer2_fru[i]] = 1
                        layer1_fru_flag[j] = 1
                        break
                else:
                    continue
        # For the first half of the unpaired input-side endpoints, link them to
        # an output node sampled proportionally to its degree in final_adj.
        zero_indices = np.where(layer1_fru_flag == 0)[0]
        layer1_r = layer1_fru[zero_indices]
        for i in range(len(layer1_r)//2):
            node_degrees = np.sum(final_adj, axis=0)
            prob_distribution = node_degrees / np.sum(node_degrees)
            # NOTE(review): same non-termination / NaN risk as the branch above.
            while True:
                target_node = np.random.choice(np.arange(layer.outdim), size=1, replace=False, p=prob_distribution)
                if final_adj[layer1_r[i], target_node] == 1:
                    continue
                else:
                    final_adj[layer1_r[i], target_node] = 1
                    break
    # debug output: total number of links in the mask
    print(int(np.sum(final_adj)))
    layer.weight_mask = torch.Tensor(final_adj).to(layer.device)
    
def create_self_correlated_sparse(model, corr, dim, soft=False, noise=False):
    """Rebuild every sparse layer's mask from a feature-correlation matrix.

    A cleaned copy of ``corr`` is used (NaNs set to 0, diagonal zeroed, and,
    with ``noise=True``, standard-normal noise added); the caller's array is
    left untouched, so repeated calls do not accumulate noise.

    * ``dim == 1``: every layer uses the cleaned matrix directly.
    * ``dim == 2``: the matrix is tiled 2x2; the first layer uses the top
      half of the tiled matrix (``corr.shape[0]`` rows) and every other layer
      the whole tiled matrix.
    * any other ``dim``: nothing happens.

    Each layer keeps ``layer.n_params`` links, chosen by ``update_topology``
    (top-|corr| if ``soft`` is False, |corr|-weighted sampling otherwise).
    """
    corr = np.array(corr, dtype=float)  # private copy: never mutate the caller's matrix
    corr[np.isnan(corr)] = 0
    np.fill_diagonal(corr, 0)

    if noise:
        corr += np.random.randn(corr.shape[0], corr.shape[1])

    layers = model.sparse_layers
    if dim == 1:
        for layer in layers:
            update_topology(layer, corr, layer.n_params, soft)
    elif dim == 2:
        tiled = np.tile(corr, (2, 2))
        half = corr.shape[0]
        for i, layer in enumerate(layers):
            topo = tiled[:half, :].copy() if i == 0 else tiled
            update_topology(layer, topo, layer.n_params, soft)
    
def update_topology(layer, corr, number_of_links, soft=False):
    """Set ``layer.weight_mask`` to a 0/1 mask with ``number_of_links`` ones.

    * ``soft=False``: keep the entries with the largest ``|corr|``.  Exactly
      ``number_of_links`` are kept (ties broken by ``torch.topk``), so an
      all-zero or tie-heavy ``corr`` no longer switches on extra entries, and
      ``number_of_links == 0`` yields an empty mask.
    * ``soft=True``: sample entries without replacement with probability
      proportional to ``|corr|``.

    The mask has ``corr``'s shape, the default float dtype, and is placed on
    ``layer.device``.  Requesting more links than entries raises a
    ``RuntimeError`` from torch.
    """
    magnitude = torch.abs(torch.Tensor(corr))
    flat = magnitude.flatten()
    adj = torch.zeros_like(flat)
    if soft:
        picked = torch.multinomial(flat / flat.sum(), number_of_links, replacement=False)
    else:
        picked = torch.topk(flat, number_of_links).indices
    adj[picked] = 1
    layer.weight_mask = adj.reshape(magnitude.shape).to(layer.device)


def create_self_correlated_scheduler(model, corr, dim):
    """Rebuild every sparse layer's mask by hard top-|corr| selection.

    A cleaned copy of ``corr`` is used (NaNs set to 0, diagonal zeroed); the
    caller's array is not modified.

    * ``dim == 1``: every layer uses the cleaned matrix directly.
    * ``dim == 2``: the matrix is tiled 2x2; the first layer uses the top
      half of the tiled matrix (``corr.shape[0]`` rows) and every other layer
      the whole tiled matrix.
    * any other ``dim``: nothing happens.

    Each layer keeps ``layer.n_params`` links via ``update_topology``.
    """
    corr = np.array(corr, dtype=float)  # private copy: never mutate the caller's matrix
    corr[np.isnan(corr)] = 0
    np.fill_diagonal(corr, 0)

    layers = model.sparse_layers
    if dim == 1:
        for layer in layers:
            update_topology(layer, corr, layer.n_params)
    elif dim == 2:
        tiled = np.tile(corr, (2, 2))
        half = corr.shape[0]
        for i, layer in enumerate(layers):
            topo = tiled[:half, :].copy() if i == 0 else tiled
            update_topology(layer, topo, layer.n_params)

def update_topology_scheduler(w, corr, number_of_links):
    """Return a 0/1 mask keeping the ``number_of_links`` largest-|corr| entries.

    Exactly ``number_of_links`` entries are set (ties broken by
    ``torch.topk``).  The previous threshold taken at sorted index
    ``number_of_links`` kept at least one extra link and failed when every
    entry was requested.  The mask has ``corr``'s shape, the default float
    dtype, and lives on ``w.device``.
    """
    magnitude = torch.abs(torch.Tensor(corr))
    flat = magnitude.flatten()
    adj = torch.zeros_like(flat)
    adj[torch.topk(flat, number_of_links).indices] = 1
    return adj.reshape(magnitude.shape).to(w.device)

def update_topology_scheduler_soft(w, corr, number_of_links):
    """Sample a 0/1 mask with ``number_of_links`` ones, biased toward large |corr|.

    Entries are drawn without replacement by ``torch.multinomial`` with
    probability proportional to ``|corr|``.  The mask has ``corr``'s shape,
    the default float dtype, and is placed on ``w.device``.
    """
    magnitude = torch.abs(torch.Tensor(corr))
    weights = magnitude.flatten()
    picked = torch.multinomial(weights / weights.sum(), number_of_links, replacement=False)
    mask = torch.zeros_like(weights)
    mask[picked] = 1
    return mask.view(magnitude.shape).to(w.device)