import torch
from torchvision import datasets, transforms
from scipy.io import loadmat, savemat
import numpy as np
from .nPSO_monopartite_bipartite import nPSO_bipartite
from .nPSO_quadripartite import nPSO_quadpartite
import random
from scipy import sparse
import torch.nn.functional as F

def load_calib_dataset(args, data_dir='./data'):
    """Build the calibration DataLoader and a zero buffer for its flattened inputs.

    Args:
        args: Namespace providing ``dataset`` (one of ``"MNIST"``,
            ``"Fashion_MNIST"``, ``"EMNIST"``) and ``calib_samples`` (used as
            the DataLoader batch size).
        data_dir: Root directory where torchvision stores/downloads the data.

    Returns:
        Tuple ``(dataloader, input_of_sparse_layer)``: a shuffled DataLoader
        over the training split (``ToTensor`` transform only) and a zero
        ``np.ndarray`` of shape ``(784, num_samples)``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    # Hard-coded column counts of the buffer (presumably the size of each
    # training split; EMNIST uses the 'balanced' split).
    sample_counts = {
        "MNIST": 60000,
        "Fashion_MNIST": 60000,
        "EMNIST": 112800,
    }
    # Validate first: previously an unknown name fell through every branch and
    # crashed at ``return`` with an UnboundLocalError.
    if args.dataset not in sample_counts:
        raise ValueError(
            f"Unsupported calibration dataset {args.dataset!r}; "
            f"expected one of {sorted(sample_counts)}"
        )

    to_tensor = transforms.Compose([transforms.ToTensor()])
    if args.dataset == "MNIST":
        dataset = datasets.MNIST(data_dir, train=True, download=True,
                                 transform=to_tensor)
    elif args.dataset == "Fashion_MNIST":
        dataset = datasets.FashionMNIST(root=data_dir, train=True,
                                        transform=to_tensor, download=True)
    else:  # "EMNIST"
        dataset = datasets.EMNIST(root=data_dir, train=True,
                                  transform=to_tensor, download=True,
                                  split='balanced')

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.calib_samples, shuffle=True)
    input_of_sparse_layer = np.zeros((784, sample_counts[args.dataset]))
    return dataloader, input_of_sparse_layer

def rewire_connections(layer):
    """Shuffle, column by column, the rows occupied by non-zero mask entries.

    For every column of ``layer.weight_mask`` the non-zero values are kept
    (in their original order) but written to freshly drawn random row
    positions; all other entries of the column become zero. The result
    replaces ``layer.weight_mask`` in place on the layer object.

    Args:
        layer: Object exposing a 2-D tensor ``weight_mask`` and a ``device``.
    """
    old_mask = layer.weight_mask
    rewired = torch.zeros_like(old_mask).to(layer.device)

    n_rows = rewired.shape[0]
    n_cols = rewired.shape[1]

    for col in range(n_cols):
        source = old_mask[:, col]
        # Values to carry over; their count is the column's connection count.
        kept_values = source[source != 0]
        # Random distinct target rows, one per kept value.
        target_rows = torch.randperm(n_rows)[:kept_values.numel()]
        rewired[target_rows, col] = kept_values

    layer.weight_mask = rewired
                


def create_sparse_topological_initialization(args, model, filename=None, eng=None):
    """Initialise the sparse masks of ``model`` according to the mode set in ``args``.

    Exactly one of the following branches runs (checked in this order):

    * ``args.self_correlated_sparse``: obtain the feature correlation matrix
      of the calibration dataset (loaded from ``<filename>/corr.mat`` when
      that file exists, otherwise computed and cached there) and pass it to
      ``create_self_correlated_sparse``.
    * ``args.BA``: call ``create_ba_sparse`` on every layer in
      ``model.sparse_layers``; the first layer is additionally rewired when
      ``args.rewire_first_layer`` is set.
    * ``args.WS``: call ``create_ws_sparse`` on every layer in
      ``model.sparse_layers``.

    Args:
        args: Namespace with the mode flags and their parameters.
        model: Model whose masks are initialised.
        filename: Directory used to cache ``corr.mat``. When ``None`` the
            correlation is computed every time and not cached.
        eng: Unused by this function.
    """
    import os

    if args.self_correlated_sparse:
        print("Using self correlated sparse of mlp!!!")

        corr_path = None if filename is None else os.path.join(filename, "corr.mat")

        # Check for the cached file itself (not just the directory) so that a
        # directory left over from an interrupted run does not break loading.
        if corr_path is not None and os.path.exists(corr_path):
            corr = loadmat(corr_path)["corr"]
        else:
            # Only touch the dataset when the correlation must be computed.
            dataloader, input_of_sparse_layer = load_calib_dataset(args, data_dir='./data')
            for batch_idx, (data, _) in enumerate(dataloader):
                start = batch_idx * args.calib_samples
                # Each batch is flattened to (batch, 784) and stored as columns.
                input_of_sparse_layer[:, start:start + data.shape[0]] = (
                    data.reshape(-1, 784).numpy().transpose(1, 0)
                )
            corr = np.corrcoef(input_of_sparse_layer)
            print("done")

            if corr_path is not None:
                os.makedirs(filename, exist_ok=True)
                savemat(corr_path, {"corr": corr})
        create_self_correlated_sparse(model, corr, args.dim, args.soft_csti, args.noise_csti)

    elif args.BA:
        for i, layer in enumerate(model.sparse_layers):
            create_ba_sparse(layer)
            if i == 0 and args.rewire_first_layer:
                rewire_connections(layer)

    elif args.WS:
        for layer in model.sparse_layers:
            create_ws_sparse(layer, args)
            
                
def soft_resort(in_neuron_degree):
    """Sample an index ordering without replacement, weighted by ``in_neuron_degree``.

    Draws as many indices as ``in_neuron_degree.shape[0]`` via
    ``torch.multinomial`` with ``replacement=False`` and returns them.
    """
    n_draws = in_neuron_degree.shape[0]
    return torch.multinomial(in_neuron_degree, num_samples=n_draws, replacement=False)


def create_ws_sparse(layer, args):
    """Set ``layer.weight_mask`` to a banded bipartite mask with optional random rewiring.

    Each column of the (smaller-dim x larger-dim) adjacency gets a contiguous,
    wrap-around band of ones centred proportionally to the column index. When
    ``args.random_rewiring`` is non-zero, every edge is removed with that
    probability and random edges are then added until the edge count reaches
    ``int((1 - layer.sparsity) * indim * outdim)``.

    Note that the band width is derived from ``args.sparsity`` while the
    regrow target uses ``layer.sparsity``.

    Args:
        layer: Object exposing ``weight_mask`` (2-D), ``device``, ``sparsity``
            and ``indim``. ``weight_mask`` is overwritten with a float tensor
            of zeros and ones.
        args: Namespace providing ``sparsity`` and ``random_rewiring``.
    """
    # NOTE(review): the previous body read the shape from an undefined name
    # ``w`` (NameError). The existing mask's shape is used instead — confirm
    # that ``layer.weight_mask`` is always initialised before this call.
    mask_shape = layer.weight_mask.shape
    indim = min(mask_shape[0], mask_shape[1])
    outdim = max(mask_shape[0], mask_shape[1])

    # Average number of connections per column; columns get either the floor
    # (K1) or floor + 1 (K2) so that the mean is approximately K.
    K = (1 - args.sparsity) * indim
    K1 = int(K)
    K2 = int(K) + 1
    my_list = [K1] * int(outdim * (K2 - K)) + [K2] * int(outdim * (K - K1) + 1)
    random.shuffle(my_list)

    adj = np.zeros((indim, outdim))

    # Column i is centred around row i * rate, wrapping modulo indim.
    rate = indim / outdim
    for i in range(outdim):
        degree = my_list[i]
        start = int(i * rate - degree / 2)
        idx = (start + np.arange(degree)) % indim
        adj[idx, i] = 1

    # rewiring
    if args.random_rewiring != 0:
        # np.nonzero yields edges in row-major order, i.e. the same order the
        # former nested (row, column) loop visited them.
        rows, cols = np.nonzero(adj)
        randomness = np.random.binomial(1, p=args.random_rewiring, size=rows.size)
        drop = randomness == 1
        adj[rows[drop], cols[drop]] = 0

        # regrow
        noRewires = int((1 - layer.sparsity) * indim * outdim) - np.sum(adj)
        nrAdd = 0
        while nrAdd < noRewires:
            i = np.random.randint(0, indim)
            j = np.random.randint(0, outdim)
            if adj[i][j] == 0:
                nrAdd += 1
                adj[i][j] = 1

        print(np.sum(adj), noRewires)

    # adj is (smaller dim, larger dim); transpose when the layer's input
    # dimension is the larger one.
    mask = torch.Tensor(adj).to(layer.device)
    if layer.indim != indim:
        layer.weight_mask = mask.t()
    else:
        layer.weight_mask = mask

# def _adjust_samples(samples, target_total):
#     """Convert float samples to integers using probabilistic rounding"""
#     # Calculate integer parts and fractional remainders
#     integer_parts = np.floor(samples).astype(int)
#     fractional = samples - integer_parts
    
#     # Calculate how many connections we need to add
#     total_integer = np.sum(integer_parts)
#     remainder = target_total - total_integer
    
#     # Probabilistically distribute remaining connections
#     if remainder > 0:
#         # Get probabilities from fractional parts
#         probs = fractional / np.sum(fractional)
#         # Randomly choose which indices get extra connection
#         extra_indices = np.random.choice(len(samples), size=remainder, p=probs, replace=True)
#         np.add.at(integer_parts, extra_indices, 1)
    
#     return integer_parts.tolist()


# def symmetric_positions(center, D_total, window_size, N_in):
#     """
#     Compute D_total positions evenly spaced within a window of length window_size,
#     symmetric about the given center.
#     """
#     half_window = (window_size - 1) // 2
#     positions = np.linspace(center - half_window, center + half_window, D_total, dtype=int)

#     positions = [(x % N_in) for x in positions]

#     unique_positions = list(set(positions))

#     while len(unique_positions) < D_total:
#         unique_positions.extend(unique_positions)  
#         unique_positions = unique_positions[:D_total] 

#     return unique_positions


# def get_closest_nodes_centered(i, N, count):
#     """
#     Returns a list of 'count' indices from a layer of size N,
#     ordered as: i, i-1, i+1, i-2, i+2, … (with modulo wrapping)
#     """
#     indices = [i]
#     d = 1
#     while len(indices) < count:
#         indices.append((i - d) % N)
#         if len(indices) < count:
#             indices.append((i + d) % N)
#         d += 1
#     return indices[:count]


# def pick_connections_for_output_node(j, N_in, N_out, sparsity, D_total, M, gamma, args):
#     """
#     For output node j, returns a list of input node indices that will be connected.
#     """
#     base_center = j * (N_in / N_out)
#     center = int(round(base_center)) % N_in

#     base_window = D_total
#     window_size = int(round(base_window + gamma * (N_in - base_window)))
#     window_size = max(D_total, min(window_size, N_in))

#     window_start = center - window_size // 2
#     uniform_centers = np.linspace(window_start, window_start + window_size, M, endpoint=False) + window_size/(2*M)
#     group_centers = (1 - gamma) * center + gamma * uniform_centers
#     group_centers = [int(round(gc)) % N_in for gc in group_centers]
    
#     if args.synaptic_dist == "fixed":
#         if gamma == 0:
#             return symmetric_positions(center, D_total, window_size, N_in)
#         else:
#             # Partition D_total over M groups
#             base = D_total // M
#             remainder = D_total % M
#             group_sizes = [base] * M
#             extra_indices = random.sample(range(M), remainder)
#             for idx in extra_indices:
#                 group_sizes[idx] += 1

#             connections = []
#             for gc, g_size in zip(group_centers, group_sizes):
#                 c = int(round(gc)) % N_in
#                 half = g_size // 2
#                 # Pick g_size contiguous nodes around c
#                 group = [(c - half + k) % N_in for k in range(g_size)]
#                 connections.extend(group)
            
#             # Force uniqueness
#             unique_connections = list(set(connections))
#             if len(unique_connections) < D_total:
#                 # Instead of padding from a global nearest list, sample additional nodes from the designated window.
#                 possible_window = [(window_start + k) % N_in for k in range(window_size)]
#                 # Exclude what we already have.
#                 available = list(set(possible_window) - set(unique_connections))
#                 needed = D_total - len(unique_connections)
#                 # If the window does not contain enough nodes, sample from the whole range without duplicates.
#                 if len(available) < needed:
#                     available = list(set(range(N_in)) - set(unique_connections))
#                 additional = random.sample(available, needed)
#                 unique_connections.extend(additional)
#             unique_connections = unique_connections[:D_total]
#             return unique_connections
#     elif args.synaptic_dist == "uniform":
#         spread = getattr(args, "uniform_spread", 2)
        
#         mean_connections = D_total / M
#         low = mean_connections * (1 - spread)
#         high = mean_connections * (1 + spread)
        
#         random_values = np.random.uniform(low, high, M)
        
#         group_sizes = (random_values / np.sum(random_values)) * D_total
#         group_sizes = np.round(group_sizes).astype(int)
        
#         diff = D_total - np.sum(group_sizes)
#         while diff != 0:
#             for i in range(M):
#                 if diff == 0:
#                     break
#                 if diff > 0:
#                     group_sizes[i] += 1
#                     diff -= 1
#                 elif diff < 0 and group_sizes[i] > 1:
#                     group_sizes[i] -= 1
#                     diff += 1
        
#         connections = []
#         for gc, g_size in zip(group_centers, group_sizes):
#             c = int(round(gc)) % N_in
#             half = g_size // 2
#             group = [(c - half + k) % N_in for k in range(g_size)]
#             connections.extend(group)
#         return connections
#     elif args.synaptic_dist == "gaussian":
#         # Gaussian parameters
#         mean_conn = D_total / M  # Target mean connections per dendrite
#         std_conn = getattr(args, "gaussian_std", mean_conn)  # Defaults to the mean itself (std == mean)
        
#         # Generate group sizes from Gaussian distribution
#         group_sizes = np.random.normal(loc=mean_conn, scale=std_conn, size=M)
        
#         # Ensure minimum of 1 connection per group and round to integers
#         group_sizes = np.clip(group_sizes, 1, None).round().astype(int)
        
#         # Adjust total to exactly match D_total
#         current_total = np.sum(group_sizes)
#         diff = D_total - current_total
        
#         # Adjustment loop (distribute difference randomly)
#         while diff != 0:
#             # Get all adjustable groups (can increase if diff >0, decrease if diff <0)
#             adjustable = np.where(
#                 (group_sizes < (2*mean_conn)) if diff > 0 else  # Allow increases
#                 (group_sizes > 1)                                # Allow decreases
#             )[0]
            
#             if len(adjustable) == 0:
#                 break  # Safety valve to prevent infinite loop
                
#             # Randomly select a group to adjust (uniform probability)
#             idx = np.random.choice(adjustable)
            
#             if diff > 0:
#                 group_sizes[idx] += 1
#                 diff -= 1
#             else:
#                 group_sizes[idx] -= 1
#                 diff += 1
        
#         # Generate connections (same spatial pattern as before)
#         connections = []
#         for gc, g_size in zip(group_centers, group_sizes):
#             c = int(round(gc)) % N_in
#             half = g_size // 2
#             group = [(c - half + k) % N_in for k in range(g_size)]
#             connections.extend(group)
        
#         return connections
#     elif args.synaptic_dist == "spatial_gaussian":
#         distances = []
#         for gc in group_centers:
#             d_abs = abs(gc - base_center)
#             d_circ = min(d_abs, N_in - d_abs)
#             distances.append(d_circ)
        
#         sigma = getattr(args, "synaptic_std",0.1* N_in)
#         gaussian_vals = np.exp(-0.5 * (np.array(distances)/sigma)**2)
#         probs = gaussian_vals / np.sum(gaussian_vals)
        
#         group_assignments = np.random.choice(M, size=D_total, p=probs)
#         group_sizes = np.bincount(group_assignments, minlength=M)

#         # Expand connections from group centers
#         connections = []
#         for gc, size in zip(group_centers, group_sizes):
#             if size == 0:
#                 continue
#             group = []
#             step = 0
#             direction = -1  # Start expanding left
#             while len(group) < size:
#                 pos = (gc + direction * step) % N_in
#                 if pos not in group:
#                     group.append(pos)
#                 if len(group) >= size:
#                     break
#                 direction *= -1  # Alternate direction
#                 if direction == 1:
#                     step += 1
#             connections.extend(group)

#         # Ensure exact D_total connections
#         unique_connections = list(set(connections))
#         if len(unique_connections) < D_total:
#             padding = get_closest_nodes_centered(int(base_center), N_in, D_total - len(unique_connections))
#             unique_connections.extend(padding)
#         return unique_connections[:D_total]

#     elif args.synaptic_dist == "spatial_inversegaussian":
#         # Inverse Gaussian: farther groups get more connections
#         distances = []
#         for gc in group_centers:
#             d_abs = abs(gc - base_center)
#             d_circ = min(d_abs, N_in - d_abs)
#             distances.append(d_circ)
        
#         sigma = getattr(args, "gaussian_sigma", 0.1 * N_in)
#         max_dist = np.max(distances)
        
#         # Inverse Gaussian calculation
#         inverse_distances = max_dist - np.array(distances)
#         gaussian_vals = np.exp(-0.5 * (inverse_distances/sigma)**2)
#         probs = gaussian_vals / np.sum(gaussian_vals)
        
#         # Assign connections to groups
#         group_assignments = np.random.choice(M, size=D_total, p=probs)
#         group_sizes = np.bincount(group_assignments, minlength=M)

#         # Expand connections from group centers
#         connections = []
#         for gc, size in zip(group_centers, group_sizes):
#             if size == 0:
#                 continue
#             group = []
#             step = 0
#             direction = 1  # Start expanding right
#             while len(group) < size:
#                 pos = (gc + direction * step) % N_in
#                 if pos not in group:
#                     group.append(pos)
#                 if len(group) >= size:
#                     break
#                 direction *= -1  # Alternate direction
#                 if direction == -1:
#                     step += 1
#             connections.extend(group)

#         # Ensure exact D_total connections
#         unique_connections = list(set(connections))
#         if len(unique_connections) < D_total:
#             padding = get_closest_nodes_centered(int(base_center), N_in, D_total - len(unique_connections))
#             unique_connections.extend(padding)
#         return unique_connections[:D_total]


# def _adjust_samples(samples, target_total):
#     """Convert float samples to integers while preserving total and ensuring >=1"""
#     rounded = np.round(samples).astype(int)
    
#     # First pass: clip to minimum 1
#     clipped = np.clip(rounded, 1, None)
#     current_total = np.sum(clipped)
#     diff = target_total - current_total

#     if diff != 0:
#         fractional = samples - rounded
        
#         if diff > 0:
#             indices = np.argsort(-fractional)[:diff]
#             clipped[indices] += 1
#         else:
#             reducible_mask = clipped > 1
#             reducible_indices = np.where(reducible_mask)[0]
            
#             if len(reducible_indices) > 0:
#                 sorted_indices = reducible_indices[np.argsort(fractional[reducible_indices])]
#                 for i in range(min(-diff, len(sorted_indices))):
#                     clipped[sorted_indices[i]] -= 1

#     clipped = np.clip(clipped, 1, None)
#     final_total = np.sum(clipped)
    
#     if final_total != target_total:
#         adj_diff = target_total - final_total
#         if adj_diff > 0:
#             for i in range(adj_diff):
#                 clipped[i % len(clipped)] += 1
#         else:
#             removed = 0
#             for i in range(len(clipped)):
#                 if clipped[i] > 1 and removed < -adj_diff:
#                     clipped[i] -= 1
#                     removed += 1
    
#     return clipped.tolist()


# def create_dendritic_sparse_scheduler(sparsity, w, args):
#     N_in = min(w.shape[0], w.shape[1])
#     N_out = max(w.shape[0], w.shape[1])
#     base_M = args.M 

#     total_target = int(round((1 - sparsity) * N_in * N_out))
#     degree_dist = getattr(args, "degree_dist", "fixed")  

#     if degree_dist == "fixed":
#         D_float = N_in * (1 - sparsity)
#         K1 = int(D_float)
#         K2 = K1 + 1
#         count_K1 = int(total_target - K2 * N_out) // (K1 - K2)
#         count_K2 = N_out - count_K1
#         connection_counts = [K1] * count_K1 + [K2] * count_K2
#     elif degree_dist == "gaussian":
#         D_float = N_in * (1 - sparsity)
#         connection_std = getattr(args, "degree_std", 2*D_float)
#         samples = np.random.normal(D_float, connection_std, N_out)
#         samples = np.clip(samples, 1, N_in)  
#         connection_counts = _adjust_samples(samples, total_target)
#     elif degree_dist == "uniform":
#         D_float = N_in * (1 - sparsity)
#         spread = getattr(args, "degree_spread", 4 * D_float)
#         samples = np.random.uniform(D_float-spread, D_float+spread, N_out)
#         connection_counts = _adjust_samples(samples, total_target)
#     elif degree_dist in ["spatial_gaussian", "spatial_inversegaussian"]:
#         center = (N_out - 1) / 2.0
#         D_float = N_in * (1 - sparsity)
#         sigma = getattr(args, "degree_std", D_float)
#         distances = np.abs(np.arange(N_out) - center)
        
#         if degree_dist == "spatial_gaussian":
#             weights = np.exp(-0.5 * (distances / sigma)**2)
#         else:
#             max_weight = np.exp(-0.5 * (0 / sigma)**2)
#             weights = max_weight - np.exp(-0.5 * (distances / sigma)**2)
#             weights = np.clip(weights, 0, None)
            
#         samples = weights * (total_target / np.sum(weights)) if np.sum(weights) > 0 else np.ones(N_out)*(total_target/N_out)
#         connection_counts = _adjust_samples(samples, total_target)
#     else:
#         raise ValueError(f"Unknown degree distribution: {degree_dist}")

#     if degree_dist not in ["spatial_gaussian", "spatial_inversegaussian"]:
#         random.shuffle(connection_counts)
#     connection_counts = np.array(connection_counts)

#     # Convert to numpy array for vector operations
#     connection_counts = np.array(connection_counts, dtype=int)
    
#     # 1. Initial clipping to valid range
#     connection_counts = np.clip(connection_counts, 1, N_in)
    
#     # 2. Gradual adjustment to reach exact total
#     current_total = np.sum(connection_counts)
#     diff = total_target - current_total
    
#     # Create adjustment sequence based on difference
#     if diff != 0:
#         # Calculate how many nodes we can adjust
#         adjustable_inc = (connection_counts < N_in).sum()
#         adjustable_dec = (connection_counts > 1).sum()
        
#         # Calculate maximum possible adjustment
#         max_inc = adjustable_inc * (N_in - 1)
#         max_dec = adjustable_dec * (N_in - 1)
        
#         if diff > 0 and diff > max_inc:
#             raise ValueError(f"Cannot add {diff} connections (max possible: {max_inc})")
#         if diff < 0 and -diff > max_dec:
#             raise ValueError(f"Cannot remove {-diff} connections (max possible: {max_dec})")
        
#         # Create probability distribution for adjustments
#         probabilities = np.ones(N_out) / N_out  # Uniform distribution
        
#         while diff != 0:
#             if diff > 0:
#                 # Find all nodes that can be increased
#                 candidates = np.where(connection_counts < N_in)[0]
#                 if len(candidates) == 0:
#                     break
#                 # Select random candidate weighted by probability
#                 idx = np.random.choice(candidates, p=probabilities[candidates]/probabilities[candidates].sum())
#                 connection_counts[idx] += 1
#                 diff -= 1
#             else:
#                 # Find all nodes that can be decreased
#                 candidates = np.where(connection_counts > 1)[0]
#                 if len(candidates) == 0:
#                     break
#                 # Select random candidate weighted by probability
#                 idx = np.random.choice(candidates, p=probabilities[candidates]/probabilities[candidates].sum())
#                 connection_counts[idx] -= 1
#                 diff += 1
    
#     # Final validation
#     final_total = np.sum(connection_counts)
#     if final_total != total_target:
#         raise RuntimeError(f"Connection count mismatch: {final_total} vs {total_target} (Δ={final_total-total_target})")

#     if sum(connection_counts) != total_target:
#         # Force correct total by adjusting first elements
#         diff = total_target - sum(connection_counts)
#         for i in range(abs(diff)):
#             idx = i % N_out
#             if diff > 0:
#                 connection_counts[idx] += 1
#             else:
#                 connection_counts[idx] = max(1, connection_counts[idx] - 1)
#     connection_counts = [max(1, min(c, N_in)) for c in connection_counts]
#     # Initialize adjacency matrix
#     adj = np.zeros((N_in, N_out), dtype=int)

#     gamma = args.gamma if args.gamma is not None else 0.5
#     gamma_dist = args.gamma_dist if args.gamma_dist is not None else "fixed"

#     M_dist = getattr(args, "M_dist")
#     M_vals = None
#     gammas = []

#     if M_dist in ["spatial_gaussian", "spatial_inversegaussian"]:
#         center_out = (N_out - 1) / 2.0
#         sigma = getattr(args, "M_std", N_out / 7.0) 
#         distances = np.abs(np.arange(N_out) - center_out)
        
#         if M_dist == "spatial_gaussian":
#             weights = np.exp(-0.5 * (distances / sigma)**2)
#         else:
#             max_weight = 1.0
#             min_weight = 0.1
#             weights = max_weight - (max_weight - min_weight) * np.exp(-0.5 * (distances / sigma)**2)
#             weights = np.clip(weights, min_weight, max_weight)
        
#         sum_weights = np.sum(weights)
#         if sum_weights == 0:
#             weights = np.ones(N_out)  
#         scaling_factor = (N_out * base_M) / sum_weights
#         weights *= scaling_factor
        
#         M_vals_float = weights.copy()
#         M_vals = np.floor(M_vals_float).astype(int) 
#         fractional = M_vals_float - M_vals
        
#         remaining = (N_out * base_M) - np.sum(M_vals)
#         if remaining > 0:
#             probs = fractional / np.sum(fractional)
#             extra_indices = np.random.choice(N_out, size=remaining, p=probs, replace=True)
#             np.add.at(M_vals, extra_indices, 1)
        
#         M_vals = np.clip(M_vals, 1, None)

#     for j in range(N_out):
#         if gamma_dist == "fixed":
#             gamma_j = gamma
#         elif gamma_dist == "gaussian":
#             mean_gamma = gamma
#             gamma_std = getattr(args, "gamma_std", gamma*0.1)
#             gamma_j = np.clip(np.random.normal(mean_gamma, gamma_std), 0, 1)
#             gammas.append(gamma_j)
#         elif gamma_dist=="uniform":
#             spread = 0.25
#             gamma_j = np.random.uniform(gamma - spread, gamma + spread)
#         elif gamma_dist == "spatial_gaussian":
#             center_out = (N_out - 1) / 2.0
#             distance = abs(j - center_out)
#             max_distance = center_out
#             normalized_dist = distance / max_distance
#             mean_gamma_j = 1.0 - normalized_dist 
#             gamma_std = getattr(args, "gamma_std", gamma*0.05)
#             gamma_j = np.random.normal(mean_gamma_j, gamma_std)
#             gamma_j = np.clip(gamma_j, 0.0, 1.0)
#         elif gamma_dist == "spatial_inversegaussian":
#             center_out = (N_out - 1) / 2.0
#             distance = abs(j - center_out)
#             max_distance = center_out
#             mean_gamma_j = gamma * (distance / max_distance)
#             gamma_std = getattr(args, "gamma_std", gamma*0.05)
#             gamma_j = np.random.normal(mean_gamma_j, gamma_std)
#             gamma_j = np.clip(gamma_j, 0.0, 1.0)
#         else:
#             raise ValueError("Unknown gamma distribution: {}".format(gamma_dist))
        
#         D_total = connection_counts[j]
#         if M_dist == "fixed":
#             M_j = base_M
#         elif M_dist == "gaussian":
#             M_std = getattr(args, "M_std", base_M /4)
#             M_j = np.random.normal(base_M, M_std / 2)
#             M_j = int(np.round(np.clip(M_j, 1, D_total)))
#         elif M_dist == "uniform":
#             spread = getattr(args, "M_spread", base_M)
#             M_j = np.random.uniform(base_M - spread, base_M + spread)
#             M_j = int(np.round(np.clip(M_j, 1, D_total)))
#         elif M_dist in ["spatial_gaussian", "spatial_inversegaussian"]:
#             M_j = M_vals[j]
#             M_j = max(1, min(M_j, D_total)) 
#         else:
#             raise ValueError(f"Unknown M distribution: {M_dist}")
#         base_window = D_total
#         window_size = int(round(base_window + gamma_j * (N_in - base_window)))
#         window_size = max(D_total, min(window_size, N_in))

#         connections = pick_connections_for_output_node(
#             j, N_in, N_out, sparsity, D_total, M_j, gamma_j, args
#         )
#         for i in connections:
#             adj[i, j] = 1
#     degrees = adj.sum(axis=0)
#     print("Degree stats:", np.min(degrees), np.max(degrees), np.mean(degrees))
#     # --- REWIRING ---
#     if args.random_rewiring != 0:
#         total_edges = int(np.sum(adj))
#         randomness = np.random.binomial(1, p=args.random_rewiring, size=total_edges)
#         count = 0
#         for i in range(N_in):
#             for j in range(N_out):
#                 if adj[i, j] == 1:
#                     if randomness[count] == 1:
#                         adj[i, j] = 0 
#                     count += 1

#         removed_edges = total_edges - int(np.sum(adj))
#         nrAdd = 0
#         while nrAdd < removed_edges:
#             i_rand = np.random.randint(0, N_in)
#             j_rand = np.random.randint(0, N_out)
#             if adj[i_rand, j_rand] == 0:
#                 adj[i_rand, j_rand] = 1
#                 nrAdd += 1
#         print("After rewiring, total edges:", np.sum(adj), "removed:", removed_edges)

#     if w.shape[0] != N_in:
#         return torch.LongTensor(adj).to(w.device).t()
#     return torch.LongTensor(adj).to(w.device)


def create_dendritic_sparse_scheduler(sparsity, w, args):
    """
    Generate a dendritic, spatial, group-structured bipartite adjacency mask.

    Every output unit (column) receives a number of incoming links (its
    degree).  The links are split into ``M`` contiguous groups ("dendrites")
    whose centres are spread evenly over a window of the input axis; the
    window width is interpolated between the degree (gamma = 0) and the full
    input width (gamma = 1).  Indices wrap around modulo the input size.

    Parameters:
        sparsity: Fraction of zero elements (1 - sparsity = fraction of ones)
        w: 2-D weight tensor; only its shape and device are read.  The smaller
           dimension is used as N_in, the larger one as N_out.
        args: Namespace with distribution parameters:
            - M (base number of groups per column)
            - degree_dist, plus optional degree_std, degree_spread,
              min_degree, max_degree
            - M_dist, plus optional M_std, M_spread, min_M, max_M
            - synaptic_dist, plus optional synaptic_std, synaptic_spread,
              min_synaptic, max_synaptic
            - gamma, gamma_dist, plus optional gamma_std, gamma_spread,
              min_gamma, max_gamma
            - random_rewiring (optional float, 0..1)
            - seed (optional; seeds the local NumPy generator)
    Returns:
        mask (torch.LongTensor): 0/1 mask on ``w.device``.  It is built as
        N_in x N_out and transposed when ``w.shape[0] != N_in`` so that it
        matches the orientation of ``w``.
    Raises:
        ValueError: if one of the ``*_dist`` settings is not supported.
    """
    spatial_kinds = ("spatial_gaussian", "spatial_inversegaussian")

    N_in = min(w.shape[0], w.shape[1])
    N_out = max(w.shape[0], w.shape[1])
    base_M = args.M
    total_target = int(round((1 - sparsity) * N_in * N_out))
    rng = np.random.default_rng(getattr(args, "seed", None))

    def spatial_profile(n, sigma, inverse):
        # Gaussian bump over positions 0..n-1: peak at the centre, or at the
        # two edges when `inverse` is set.
        dists = np.abs(np.arange(n) - (n - 1) / 2.0)
        if inverse:
            dists = np.max(dists) - dists
        return np.exp(-0.5 * (dists / sigma) ** 2)

    def adjust_samples(samples, target_total):
        # Floor the samples, then hand out the missing units at random,
        # weighted by the fractional parts, to approach `target_total`.
        integer_parts = np.floor(samples).astype(int)
        fractional = samples - integer_parts
        remainder = int(target_total - np.sum(integer_parts))
        if remainder > 0:
            fractional_total = np.sum(fractional)
            if fractional_total > 0:
                probs = fractional / fractional_total
            else:
                # All inputs were already integers: 0/0 would give NaN
                # probabilities and make rng.choice raise, so spread the
                # remainder uniformly instead.
                probs = np.full(len(samples), 1.0 / len(samples))
            idx = rng.choice(len(samples), size=remainder, p=probs)
            np.add.at(integer_parts, idx, 1)
        return np.clip(integer_parts, 1, N_in)

    # 1. Per-column degree (number of input connections for each output)
    D_float = N_in * (1 - sparsity)
    if args.degree_dist == "fixed":
        per_col_degree = np.full(N_out, int(round(D_float)))
    elif args.degree_dist == "gaussian":
        degree_std = getattr(args, "degree_std", 2 * D_float)
        per_col_degree = rng.normal(D_float, degree_std, N_out)
        per_col_degree = np.clip(per_col_degree, 1, N_in)
        per_col_degree = np.round(per_col_degree).astype(int)
    elif args.degree_dist == "uniform":
        spread = getattr(args, "degree_spread", 4 * D_float)
        per_col_degree = rng.uniform(D_float - spread, D_float + spread, N_out)
        per_col_degree = np.clip(per_col_degree, 1, N_in)
        per_col_degree = np.round(per_col_degree).astype(int)
    elif args.degree_dist in spatial_kinds:
        sigma = getattr(args, "degree_std", N_out / 7.0)
        vals = spatial_profile(
            N_out, sigma, args.degree_dist == "spatial_inversegaussian"
        )
        min_deg = getattr(args, "min_degree", 1)
        max_deg = getattr(args, "max_degree", int(round(D_float)))
        per_col_degree = min_deg + (max_deg - min_deg) * vals
        per_col_degree = np.clip(np.round(per_col_degree), 1, N_in).astype(int)
    else:
        raise ValueError("Unsupported degree_dist in vectorized version")

    # Top the degrees up towards the requested total number of links.
    per_col_degree = adjust_samples(per_col_degree, total_target)

    # 2. Group structure (per-column number of dendrites/groups)
    if args.M_dist == "fixed":
        M = np.full(N_out, base_M)
    elif args.M_dist == "gaussian":
        M_std = getattr(args, "M_std", base_M / 4)
        M = rng.normal(base_M, M_std, N_out)
        M = np.clip(np.round(M), 1, None).astype(int)
    elif args.M_dist == "uniform":
        spread = getattr(args, "M_spread", base_M)
        M = rng.uniform(base_M - spread, base_M + spread, N_out)
        M = np.clip(np.round(M), 1, None).astype(int)
    elif args.M_dist in spatial_kinds:
        sigma = getattr(args, "M_std", N_out / 7.0)
        vals = spatial_profile(
            N_out, sigma, args.M_dist == "spatial_inversegaussian"
        )
        min_M, max_M = getattr(args, "min_M", 1), getattr(args, "max_M", base_M)
        M = min_M + (max_M - min_M) * vals
        M = np.clip(np.round(M), 1, None).astype(int)
    else:
        raise ValueError("Unsupported M_dist in vectorized version")

    # 3. Gamma structure (window spreading, per column)
    if args.gamma_dist == "fixed":
        gamma = np.full(N_out, args.gamma)
    elif args.gamma_dist == "gaussian":
        gamma_std = getattr(args, "gamma_std", args.gamma * 0.1)
        gamma = rng.normal(args.gamma, gamma_std, N_out)
        gamma = np.clip(gamma, 0, 1)
    elif args.gamma_dist == "uniform":
        spread = getattr(args, "gamma_spread", 0.25)
        gamma = rng.uniform(args.gamma - spread, args.gamma + spread, N_out)
        gamma = np.clip(gamma, 0, 1)
    elif args.gamma_dist in spatial_kinds:
        sigma = getattr(args, "gamma_std", N_out / 7.0)
        gamma = spatial_profile(
            N_out, sigma, args.gamma_dist == "spatial_inversegaussian"
        )
        # Rescale the (0, 1] profile to (min_gamma, max_gamma)
        min_g = getattr(args, "min_gamma", 0.0)
        max_g = getattr(args, "max_gamma", 1.0)
        gamma = min_g + (max_g - min_g) * gamma
    else:
        raise ValueError("Unsupported gamma_dist in vectorized version")

    # 4. Window centre of every output in input space.  Only the degree
    #    distribution changes the anchoring; every other setting uses the
    #    linear input/output mapping.
    if args.degree_dist == "spatial_gaussian":
        input_anchors = np.full(N_out, N_in // 2)
    elif args.degree_dist == "spatial_inversegaussian":
        input_anchors = np.linspace(0, N_in - 1, N_out)
    else:
        input_anchors = np.arange(N_out) * (N_in / N_out)
    center = np.round(input_anchors).astype(int) % N_in

    #    Window size for each column: between the degree and the full width.
    window_size = np.round(per_col_degree + gamma * (N_in - per_col_degree)).astype(int)
    window_size = np.clip(window_size, per_col_degree, N_in)

    # 5. For each output neuron, partition its degree into groups.
    #    Kept as plain lists: the number of groups may differ per column, so
    #    the rows cannot be stacked into one rectangular array.
    group_sizes = []
    group_centers = []
    for j in range(N_out):
        m_j = M[j]
        d_j = per_col_degree[j]
        # Assign group sizes by synaptic_dist
        if args.synaptic_dist == "fixed":
            base = d_j // m_j
            sizes = np.full(m_j, base)
            sizes[:d_j % m_j] += 1
        elif args.synaptic_dist == "uniform":
            mean = d_j / m_j
            spread = getattr(args, "synaptic_spread", 2)
            low = max(1, mean * (1 - spread))
            high = min(mean * (1 + spread), d_j)
            samples = rng.uniform(low, high, m_j)
            sizes = adjust_samples(samples, d_j)
        elif args.synaptic_dist == "gaussian":
            mean = d_j / m_j
            std = getattr(args, "synaptic_std", mean)
            samples = rng.normal(mean, std, m_j)
            samples = np.clip(samples, 1, d_j)
            sizes = adjust_samples(samples, d_j)
        elif args.synaptic_dist in spatial_kinds:
            sigma = getattr(args, "synaptic_std", m_j / 3.0)
            vals = spatial_profile(
                m_j, sigma, args.synaptic_dist == "spatial_inversegaussian"
            )
            min_size = getattr(args, "min_synaptic", 1)
            max_size = getattr(args, "max_synaptic", int(np.ceil(d_j / m_j)))
            group_sizes_f = min_size + (max_size - min_size) * vals
            samples = np.clip(group_sizes_f, 1, d_j)
            sizes = adjust_samples(samples, d_j)
        else:
            raise ValueError("Unsupported synaptic_dist in vectorized version")
        group_sizes.append(sizes)

        # Group centres: equally spread over the window, wrapped modulo N_in
        win_sz = window_size[j]
        win_start = center[j] - win_sz // 2
        group_center_j = (
            (win_start + np.linspace(0, win_sz, m_j, endpoint=False) + win_sz / (2 * m_j))
            .round()
            .astype(int)
        ) % N_in
        group_centers.append(group_center_j)

    # 6. Assign connections within windows/groups
    adj = np.zeros((N_in, N_out), dtype=int)
    for j in range(N_out):
        for gc, gs in zip(group_centers[j], group_sizes[j]):
            # `gs` contiguous positions centred at gc, modulo N_in
            rows = (gc - gs // 2 + np.arange(gs)) % N_in
            adj[rows, j] = 1

        assigned = int(adj[:, j].sum())
        wanted = int(per_col_degree[j])
        if assigned < wanted:
            # Overlapping groups left a deficit: sample additional positions
            # inside the window, falling back to the whole column if the
            # window has too few free rows.
            existing = set(np.flatnonzero(adj[:, j]))
            possible = set(
                (center[j] + np.arange(-window_size[j] // 2, window_size[j] // 2 + 1)) % N_in
            )
            available = list(possible - existing)
            need = wanted - assigned
            if len(available) < need:
                available = list(set(range(N_in)) - existing)
            fill = rng.choice(available, need, replace=False)
            adj[fill, j] = 1
        elif assigned > wanted:
            # Too many rows assigned: randomly turn the surplus off
            on_idx = np.flatnonzero(adj[:, j])
            drop = rng.choice(on_idx, assigned - wanted, replace=False)
            adj[drop, j] = 0

    # 7. Optional random rewiring step (preserves total number of edges)
    rewire_fraction = getattr(args, "random_rewiring", 0)
    if rewire_fraction > 0:
        edge_locs = np.argwhere(adj == 1)
        n_rewire = int(rewire_fraction * len(edge_locs))
        rewire_idx = rng.choice(len(edge_locs), n_rewire, replace=False)
        # Turn off those edges
        dropped = edge_locs[rewire_idx]
        adj[dropped[:, 0], dropped[:, 1]] = 0
        # Add new edges at uniformly random empty cells to restore the count
        added = 0
        while added < n_rewire:
            i = rng.integers(0, N_in)
            j = rng.integers(0, N_out)
            if adj[i, j] == 0:
                adj[i, j] = 1
                added += 1

    mask = torch.from_numpy(adj).long()
    if w.shape[0] != N_in:
        mask = mask.t()
    return mask.to(w.device)



def create_ws_sparse_scheduler(sparsity, w, args):
    """
    Build a Watts-Strogatz style band mask with optional random rewiring.

    The smaller dimension of ``w`` is used as the row count and the larger as
    the column count.  Each column gets a contiguous (wrapping) run of rows
    whose length is floor or floor + 1 of the mean degree.  When
    ``args.random_rewiring`` is non-zero, every link is dropped with that
    probability and links are regrown at uniformly random empty cells until
    ``int((1 - sparsity) * rows * cols)`` links exist.

    Uses the global ``random`` and ``np.random`` generators and prints the
    final link count together with the number of regrown links.

    Returns:
        torch.LongTensor on ``w.device``, oriented like ``w``.
    """
    n_rows = min(w.shape[0], w.shape[1])
    n_cols = max(w.shape[0], w.shape[1])

    mean_degree = (1 - sparsity) * n_rows
    deg_lo = int(mean_degree)
    deg_hi = int(mean_degree) + 1
    n_lo = int(n_cols * (deg_hi - mean_degree))
    n_hi = int(n_cols * (mean_degree - deg_lo) + 1)
    degrees = [deg_lo] * n_lo + [deg_hi] * n_hi
    # Shuffled twice, exactly like the original, so the RNG stream matches.
    random.shuffle(degrees)

    adj = np.zeros((n_rows, n_cols))

    step = n_rows / n_cols
    random.shuffle(degrees)
    for col in range(n_cols):
        deg = degrees[col]
        first = int(col * step - deg / 2)
        band_rows = [(first + offset) % n_rows for offset in range(deg)]
        adj[band_rows, col] = 1

    # rewiring
    if args.random_rewiring != 0:
        drop_flags = np.random.binomial(1, p=args.random_rewiring, size=int(np.sum(adj)))
        # argwhere lists the links in row-major order, the same order in which
        # a nested row/column scan would consume the flags.
        links = np.argwhere(adj == 1)
        cut = links[drop_flags == 1]
        adj[cut[:, 0], cut[:, 1]] = 0

        # regrow
        noRewires = int((1 - sparsity) * n_rows * n_cols) - np.sum(adj)
        regrown = 0
        while regrown < noRewires:
            r = np.random.randint(0, n_rows)
            c = np.random.randint(0, n_cols)
            if adj[r][c] == 0:
                regrown += 1
                adj[r][c] = 1

        print(np.sum(adj), noRewires)

    mask = torch.LongTensor(adj).to(w.device)
    if w.shape[0] != n_rows:
        return mask.t()
    return mask


def create_cws_sparse_scheduler(sparsity, w, args):
    """
    Ring-lattice mask with random edge dropping and Gaussian regrowth.

    1. Build a band (ring lattice): every column gets ``floor(k)`` or
       ``floor(k) + 1`` contiguous rows (wrapping), with
       ``k = (1 - sparsity) * indim``.
    2. Drop each link with probability ``args.random_rewiring`` (default 0).
    3. If fewer than ``int((1 - sparsity) * indim * outdim)`` links remain,
       regrow the missing ones at cells drawn from a 2-D Gaussian centred on
       the middle of the matrix (``args.sigma_x``, ``args.sigma_y``,
       ``args.rho``), widening the Gaussian when progress stalls and falling
       back to uniform sampling for the last few links.

    When the lattice already holds at least the target number of links
    (rounding can overshoot by a few), no regrowth is attempted and the mask
    is returned as built.

    Uses the global ``np.random`` generator.

    Returns:
        torch.LongTensor on ``w.device``, oriented like ``w``.
    """
    indim, outdim = min(w.shape), max(w.shape)
    device = w.device

    # 1. ring-lattice ---------------------------------------------------
    k = (1 - sparsity) * indim
    k_low = int(np.floor(k))
    k_high = k_low + 1
    n_high = int(round(outdim * (k - k_low)))
    degrees = np.array([k_high] * n_high + [k_low] * (outdim - n_high))
    np.random.shuffle(degrees)

    adj = np.zeros((indim, outdim), np.uint8)
    ratio = indim / outdim
    for col, d in enumerate(degrees):
        start = int(round(col * ratio - d / 2))
        adj[(start + np.arange(d)) % indim, col] = 1

    # 2. drop edges -----------------------------------------------------
    p = getattr(args, "random_rewiring", 0.0)
    if p > 0:
        e = np.argwhere(adj)            # (N, 2)
        keep = np.random.rand(len(e)) > p
        adj[:] = 0
        adj[e[keep, 0], e[keep, 1]] = 1

    # 3. global Gaussian regrowth --------------------------------------
    target = int((1 - sparsity) * indim * outdim)
    missing = target - int(adj.sum())
    # Only a positive deficit needs regrowth.  A negative value must not
    # enter the loop: it would be used as a slice bound and as a sample size.
    if missing > 0:
        sx, sy = args.sigma_x, args.sigma_y   # already in **cells**
        rho = args.rho
        cov = np.array([[sx**2, rho * sx * sy],
                        [rho * sx * sy, sy**2]])
        mean = np.array([indim / 2, outdim / 2])

        batch = 2048
        stuck = 0
        while missing > 0:
            pts = np.random.multivariate_normal(mean, cov, batch)
            ii = np.mod(np.rint(pts[:, 0]).astype(int), indim)
            jj = np.mod(np.rint(pts[:, 1]).astype(int), outdim)

            flat = np.unique(ii * outdim + jj)        # drop duplicates
            free = flat[adj.ravel()[flat] == 0]       # keep only zeros
            take = free[:missing]
            adj.ravel()[take] = 1
            gained = take.size
            missing -= gained

            # adapt batch / sigma if progress is poor
            if gained < 0.01 * batch:
                stuck += 1
            else:
                stuck = 0
            if stuck >= 5:
                sx *= 1.5
                sy *= 1.5
                cov[0, 0], cov[1, 1] = sx**2, sy**2
                cov[0, 1] = cov[1, 0] = rho * sx * sy
                stuck, batch = 0, min(batch * 2, 65536)

            # fallback to uniform for the very last few %
            if 0 < missing < 0.05 * target * p:
                zeros = np.where(adj.ravel() == 0)[0]
                picks = np.random.choice(zeros, missing, replace=False)
                adj.ravel()[picks] = 1
                missing = 0

    mask = torch.from_numpy(adj).long().to(device)
    return mask.t() if w.shape[0] != indim else mask

def create_ws1_sparse_scheduler(sparsity, w, args):
    """
    Band mask with per-column, distance-weighted rewiring.

    A wrapping band is built as in the plain Watts-Strogatz scheduler.  When
    ``args.random_rewiring > 0``, each column cuts
    ``int(degree * random_rewiring)`` of its links at random and regrows the
    same number one at a time among the currently empty rows of that column:

    * ``args.delta == 1.0``: the empty row with the smallest ring distance to
      ``j % indim`` is taken (first one on ties);
    * ``args.delta < 1.0``: an empty row is sampled with probability
      proportional to ``(1 + distance) ** -alpha`` where
      ``alpha = delta / (1 - delta)``.

    Column degrees are preserved.  Uses the global ``random`` and
    ``np.random`` generators.

    Returns:
        torch.LongTensor on ``w.device``, oriented like ``w``.
    """
    device = w.device
    indim = min(w.shape[0], w.shape[1])
    outdim = max(w.shape[0], w.shape[1])
    # Orientation is decided from the weight tensor itself; no local below
    # may reuse the name `w`, otherwise this decision would be corrupted.
    needs_transpose = w.shape[0] != indim

    K = (1 - sparsity) * indim

    K1 = int(K)
    K2 = int(K) + 1
    my_list = [K1] * int(outdim * (K2 - K)) + [K2] * int(outdim * (K - K1) + 1)
    random.shuffle(my_list)

    adj = np.zeros((indim, outdim))

    rate = indim / outdim
    random.shuffle(my_list)
    for i in range(outdim):
        idx = [(int(i * rate - my_list[i] / 2) + j) % indim for j in range(my_list[i])]
        adj[idx, i] = 1

    # rewiring
    if args.random_rewiring > 0:
        target_links = int(adj.sum())
        alpha = None
        if args.delta < 1.0:
            alpha = args.delta / (1.0 - args.delta)          # exponent >= 0

        for j in range(outdim):                              # one *column*
            col_links = np.flatnonzero(adj[:, j])            # current 1-positions
            n_rewire = int(len(col_links) * args.random_rewiring)
            if n_rewire == 0:
                continue

            # 1) disconnect a random subset of existing links
            victims = np.random.choice(col_links, size=n_rewire, replace=False)
            adj[victims, j] = 0

            # 2) re-grow them one by one, weighted by ring distance
            for _ in range(n_rewire):
                free = np.flatnonzero(adj[:, j] == 0)
                if free.size == 0:
                    break  # column saturated - should not happen

                d = np.abs(free - (j % indim))
                d = np.minimum(d, indim - d)                 # ring distance
                if args.delta == 1.0:
                    # delta = 1: deterministically pick the closest free row
                    new_i = free[d == d.min()][0]
                else:
                    # delta < 1: sample with probability (1 + d) ** -alpha
                    weights = (1.0 + d) ** (-alpha)
                    new_i = np.random.choice(free, p=weights / weights.sum())

                adj[new_i, j] = 1                            # add new link

        # sanity check: rewiring must not change the number of links
        assert adj.sum() == target_links, \
               "Degree changed during rewiring – check logic"

    mask = torch.from_numpy(adj).long().to(device)
    return mask.t() if needs_transpose else mask


def create_ws2_sparse_scheduler(sparsity, w, args):
    """
    Band mask whose rewired links may only land outside the original band.

    A wrapping band is built first and remembered.  When
    ``args.random_rewiring > 0``, each column cuts a random subset of its
    links and regrows the same number on rows that were *not* part of the
    original band of that column, either deterministically by ring distance
    (``args.delta == 1.0``) or sampled with probability proportional to
    ``(1 + distance) ** -alpha`` with ``alpha = delta / (1 - delta)``.

    The number of links rewired in a column is
    ``int(degree * random_rewiring)``, capped at the number of off-band rows
    of that column so that every cut link can be regrown and column degrees
    are preserved.

    Uses the global ``random`` and ``np.random`` generators.

    Returns:
        torch.LongTensor on ``w.device``, oriented like ``w``.
    """
    device = w.device
    indim = min(w.shape[0], w.shape[1])
    outdim = max(w.shape[0], w.shape[1])

    # ---------- 1. build the deterministic ring-band sparsity  ----------
    K = (1 - sparsity) * indim
    K1, K2 = int(K), int(K) + 1
    my_list = [K1] * int(outdim * (K2 - K)) + [K2] * int(outdim * (K - K1) + 1)
    random.shuffle(my_list)

    adj = np.zeros((indim, outdim), dtype=np.int8)
    rate = indim / outdim
    for i in range(outdim):
        idx = [(int(i * rate - my_list[i] / 2) + j) % indim for j in range(my_list[i])]
        adj[idx, i] = 1

    # ---------- 2. remember that initial band so we can block it later  ----------
    band_mask = adj.astype(bool)        # True == position was on the original band

    # ---------- 3. random rewiring, but only to off-band rows  ----------
    if args.random_rewiring > 0:
        target_links = int(adj.sum())    # for the final assert
        alpha = None
        if args.delta < 1.0:
            alpha = args.delta / (1.0 - args.delta)   # alpha >= 0

        for j in range(outdim):                       # work column-wise
            col_links = np.flatnonzero(adj[:, j])
            n_rewire = int(len(col_links) * args.random_rewiring)
            # Never cut more links than there are empty off-band rows to
            # receive them; otherwise the column would lose links.
            off_band_room = int(np.count_nonzero((adj[:, j] == 0) & (~band_mask[:, j])))
            n_rewire = min(n_rewire, off_band_room)
            if n_rewire == 0:
                continue

            # 1) cut a random subset of existing links
            victims = np.random.choice(col_links, size=n_rewire, replace=False)
            adj[victims, j] = 0

            # 2) add the same number of off-band links
            for _ in range(n_rewire):
                free = np.flatnonzero((adj[:, j] == 0) & (~band_mask[:, j]))

                d = np.abs(free - (j % indim))
                d = np.minimum(d, indim - d)            # ring distance
                if args.delta == 1.0:
                    new_i = free[d == d.min()][0]       # first closest row
                else:
                    w_prob = (1.0 + d) ** (-alpha)
                    new_i = np.random.choice(free, p=w_prob / w_prob.sum())

                adj[new_i, j] = 1

        # sanity check: rewiring must not change the number of links
        assert adj.sum() == target_links, \
               "Degree changed during rewiring – check logic"

    mask = torch.from_numpy(adj).long().to(device)
    return mask.t() if w.shape[0] != indim else mask

def create_ws3_sparse_scheduler(sparsity, w, args):
    """
    Row-wise, distance-weighted sparse mask.

    The smaller dimension of ``w`` indexes rows, the larger one columns.
    Every row receives a number of links and chooses its columns by ring
    distance from the row's centre column ``(i * (cols // rows)) % cols``:

    * ``args.delta == 0``: the closest columns are taken;
    * otherwise: columns are sampled without replacement with probability
      proportional to ``(1 + distance) ** -alpha``,
      ``alpha = (1 - delta) / delta``.

    Row degrees: with ``args.degree_dist == "uniform"`` every row starts with
    one link and the remaining links are handed out to uniformly random rows;
    otherwise rows get floor or floor + 1 of the mean degree, shuffled.

    Uses the global ``random`` and ``np.random`` generators.

    Returns:
        torch.LongTensor on ``w.device``, oriented like ``w``.
    Raises:
        ValueError: if ``sparsity`` is outside [0, 1), ``delta`` is outside
            [0, 1], or the uniform allocation has fewer links than rows.
    """
    delta = args.delta
    if not (0.0 <= sparsity < 1.0):
        raise ValueError("'sparsity' must be in [0, 1).")
    if not (0.0 <= delta <= 1.0):
        raise ValueError("'delta' must be in [0, 1].")

    device = w.device
    in_dim = min(w.shape[0], w.shape[1])
    out_dim = max(w.shape[0], w.shape[1])

    mean_degree = (1.0 - sparsity) * out_dim
    total_links = int(round(mean_degree * in_dim))

    # ----- per-row degrees -------------------------------------------------
    if getattr(args, "degree_dist", None) == "uniform":
        if total_links < in_dim:
            raise ValueError("Total links < number of rows: cannot give each row degree 1")
        # one guaranteed link per row, the rest scattered uniformly
        row_degrees = [1] * in_dim
        for _ in range(total_links - in_dim):
            row_degrees[random.randrange(in_dim)] += 1
    else:
        # floor / floor + 1 split so that the sum approximates the mean
        deg_lo = int(mean_degree)
        deg_hi = int(mean_degree) + 1
        n_hi = int(round(in_dim * (mean_degree - deg_lo)))
        row_degrees = [deg_lo] * (in_dim - n_hi) + [deg_hi] * n_hi
        random.shuffle(row_degrees)

    # exponent of the distance weighting; None means "take the closest"
    alpha = None if delta == 0.0 else (1.0 - delta) / delta

    # ----- fill the adjacency row by row -----------------------------------
    adj = np.zeros((in_dim, out_dim), dtype=np.int8)
    columns = np.arange(out_dim)
    stride = out_dim // in_dim
    for row, degree in enumerate(row_degrees):
        if degree <= 0:
            continue

        hub = (row * stride) % out_dim
        gap = np.abs(columns - hub)
        ring = np.minimum(gap, out_dim - gap)  # ring distance

        if alpha is None:
            picked = np.argsort(ring)[:degree]
        else:
            weights = (1.0 + ring) ** (-alpha)
            weights = weights / weights.sum()
            picked = np.random.choice(out_dim, size=degree, replace=False, p=weights)

        adj[row, picked] = 1

    mask = torch.from_numpy(adj).long().to(device)
    # transpose when the columns of `w` are not the larger dimension
    if w.shape[1] != out_dim:
        return mask.t()
    return mask





def create_ws_cross_scheduler(sparsity, w, args):
    """
    Watts-Strogatz 'cross' mask: links laid on two crossing diagonals.

    Each column gets a degree of floor or floor + 1 of
    ``(1 - sparsity) * indim``.  The first ``ceil(k / 2)`` offsets are placed
    around the main-diagonal centre of the column and the first
    ``floor(k / 2)`` offsets around the anti-diagonal centre (rows wrap
    modulo ``indim``; cells hit by both diagonals count once).

    When ``args.random_rewiring`` is truthy, each link is dropped with that
    probability and, if the mask then holds fewer than
    ``int((1 - sparsity) * indim * outdim)`` links, the deficit is regrown at
    uniformly random empty cells.

    Uses the global ``np.random`` generator.

    Returns:
        torch.LongTensor on ``w.device``, oriented like ``w``.
    """
    device = w.device
    indim = min(w.shape[0], w.shape[1])
    outdim = max(w.shape[0], w.shape[1])

    # ----- per-column degrees ----------------------------------------------
    mean_degree = (1.0 - sparsity) * indim
    deg_lo, deg_hi = int(mean_degree), int(mean_degree) + 1
    n_lo = int(outdim * (deg_hi - mean_degree))
    degrees = np.empty(outdim, dtype=np.int16)
    degrees[:n_lo] = deg_lo
    degrees[n_lo:] = deg_hi
    np.random.shuffle(degrees)

    # ----- candidate rows on both diagonals --------------------------------
    widest = degrees.max()
    offsets = np.arange(widest) - widest // 2          # shape (widest,)

    col_ids = np.arange(outdim)
    slope = indim / outdim
    main_centres = (col_ids * slope).astype(np.int64)
    cross_centres = (indim - 1 - col_ids * slope).astype(np.int64)

    # (outdim, widest) tables of wrapped row indices
    main_rows = (main_centres[:, None] + offsets[None, :]) % indim
    cross_rows = (cross_centres[:, None] + offsets[None, :]) % indim

    # ----- fill the mask column by column (each column has its own k) ------
    adj = np.zeros((indim, outdim), dtype=np.int8)
    for col, k in enumerate(degrees):
        if k == 0:
            continue
        n_main = (k + 1) // 2
        n_cross = k // 2
        adj[main_rows[col, :n_main], col] = 1
        adj[cross_rows[col, :n_cross], col] = 1

    link_count = adj.sum()

    # ----- rewiring --------------------------------------------------------
    if args.random_rewiring:
        # 1. drop links
        rows_on, cols_on = np.nonzero(adj)
        keep = np.random.rand(link_count) >= args.random_rewiring
        dropped = ~keep
        adj[rows_on[dropped], cols_on[dropped]] = 0
        link_count = keep.sum()

        # 2. regrow on any empty cell
        target = int((1.0 - sparsity) * indim * outdim)
        missing = target - link_count
        if missing > 0:
            rows_off, cols_off = np.nonzero(adj == 0)
            chosen = np.random.choice(len(rows_off), missing, replace=False)
            adj[rows_off[chosen], cols_off[chosen]] = 1

    # ----- tensor with the right orientation and device --------------------
    mask = torch.from_numpy(adj).to(device, dtype=torch.long)
    if w.shape[0] != indim:
        return mask.t()
    return mask



def spatial_sort_order(N):
    """
    Return the positions 0..N-1 ordered "centre, then halves, then quarters, ...".

    At refinement level k (1, 2, 4, ...) the candidate positions are
    ``round((2*m + 1) / (2*k) * (N - 1))`` for m = 0..k-1; a candidate is
    appended only the first time it appears.

    Example: N=8 gives [4, 2, 5, 1, 3, 6, 0, 7].
    """
    order = []
    taken = set()          # mirrors `order` for constant-time membership tests
    span = N - 1
    level = 1
    while len(order) < N:
        for m in range(level):
            slot = round((2*m + 1) / (2*level) * span)
            if slot in taken:
                continue
            taken.add(slot)
            order.append(slot)
            if len(order) == N:
                break
        level *= 2
    return order

def create_BHI_sparse_scheduler(sparsity, w, args, is_last_layer=False):
    """
    Build a mask from the bipartite nPSO generator, optionally re-ordered so
    that high-degree nodes sit at the centre of each axis.

    The two sizes of ``w`` are swapped (and the result transposed back at the
    end) when the second size is smaller than the first or is not a multiple
    of it.  The generator is called with ``args.BHI_T``, ``args.BHI_gamma``,
    ``args.BHI_distr`` and ``args.rewire_mode``; only the first element of
    its result is used.

    When ``args.degree_allocation`` is truthy, columns are permuted so that
    the column with the largest sum lands on the midpoint, the next ones on
    the midpoints of the two halves, and so on.  Rows get the same treatment
    unless ``is_last_layer`` is set.

    Returns:
        torch int64 tensor on ``w.device``.
    """
    Ni, No = w.shape
    transposed = No < Ni or (No % Ni) != 0
    if transposed:
        Ni, No = No, Ni

    adj, *_ = nPSO_bipartite(
        Ni, No,
        sparsity,
        args.BHI_T,
        args.BHI_gamma,
        args.BHI_distr,
        args.rewire_mode
    )

    # densify: sparse containers expose `toarray`, anything else is array-like
    if hasattr(adj, "toarray"):
        mat = adj.toarray().astype(int)
    else:
        mat = np.array(adj, dtype=int)

    def midpoint_order(n):
        # Breadth-first bisection of [0, n): midpoint of the whole range,
        # then midpoints of the left and right halves, and so on.
        visited = []
        pending = [(0, n)]
        head = 0
        while head < len(pending) and len(visited) < n:
            lo, hi = pending[head]
            head += 1
            width = hi - lo
            if width <= 0:
                continue
            mid = lo + (width - 1) // 2
            visited.append(mid)
            pending.append((lo, mid))
            pending.append((mid + 1, hi))
        return visited

    def centre_out_permutation(sums, n):
        # Source index for every destination slot: the largest sum goes to
        # the first slot of the bisection order, the second largest to the
        # next slot, etc.
        ranked = np.argsort(-sums)
        slots = midpoint_order(n)
        perm = np.empty(n, dtype=int)
        for rank, source in enumerate(ranked):
            destination = slots[rank] if rank < len(slots) else rank
            perm[destination] = source
        return perm.astype(int)

    if args.degree_allocation:
        # 1) permute columns by column sum
        mat = mat[:, centre_out_permutation(mat.sum(axis=0), No)]

        # 2) permute rows by row sum if not last layer
        if not is_last_layer:
            mat = mat[centre_out_permutation(mat.sum(axis=1), Ni), :]

    mask = torch.from_numpy(mat.astype(np.int64))
    if transposed:
        mask = mask.t()
    return mask.to(w.device)


def create_QHI_sparse_scheduler(sparsity, W, args):
    """
    Build the three masks of a four-layer stack with one quadripartite call.

    ``W`` is indexed as ``[W0, W1, W2]``.  The four layer sizes are read as
    A = W0.shape[1], B = W0.shape[0], C = W1.shape[0], D = W2.shape[0] and
    passed to the generator together with ``args.BHI_T``, ``args.BHI_gamma``
    and a distribution name ('random' when ``args.BHI_distr == 0``, otherwise
    'community').  The coordinate outputs of the generator are discarded.

    Returns:
        list of three torch.LongTensor masks ``[mask0, mask1, mask2]``, each
        moved to the device of the matching weight matrix.
    """
    first, middle, last = W[0], W[1], W[2]

    # layer sizes A -> B -> C -> D
    size_a = first.shape[1]
    size_b = first.shape[0]
    size_c = middle.shape[0]
    size_d = last.shape[0]

    # map args.BHI_distr to the generator's distribution name
    if args.BHI_distr == 0:
        distr = 'random'
    else:
        distr = 'community'

    # all three bipartite masks at once; coordinates are not needed here
    x_DC, x_CB, x_BA, _coords_a, _coords_b, _coords_c, _coords_d = nPSO_quadpartite(
        size_a, size_b, size_c, size_d,
        sparsity,
        args.BHI_T,
        args.BHI_gamma,
        distr
    )

    # pair each mask with the weight it belongs to: BA->W0, CB->W1, DC->W2
    paired = ((x_BA, first), (x_CB, middle), (x_DC, last))
    return [torch.LongTensor(block).to(weight.device) for block, weight in paired]




def generate_barabasi_alberta_graph(N, m):
    """
    Generate a symmetric preferential-attachment adjacency matrix.

    The first ``int(m) + 1`` nodes form a clique.  Every later node attaches
    to ``int(m)`` or ``int(m) + 1`` distinct earlier nodes, chosen with
    probability proportional to their current degree; the mix of the two
    attachment counts is taken from the fractional part of ``m``.

    An integer ``m`` is handled exactly like the equal float value (it no
    longer yields an empty graph).  A message is printed when ``m`` is not an
    ``int`` instance.

    Uses the global ``random`` and ``np.random`` generators.

    Parameters:
        N: number of nodes.
        m: average number of links added per new node (int or float).
    Returns:
        (N, N) float ndarray with 0/1 entries, symmetric, zero diagonal.
    Raises:
        ValueError: if ``N`` is smaller than the seed clique (``int(m) + 1``).
    """
    adj = np.zeros((N, N))
    if not isinstance(m, int):
        print("m is not an integer")
    m1 = int(m)
    m2 = int(m) + 1
    if N < m2:
        raise ValueError("N must be at least int(m) + 1 to hold the seed clique")

    # seed clique on the first m2 nodes
    upper = np.triu(np.ones((m2, m2)), k=1)
    adj[:m2, :m2] = upper + upper.T

    # attachment count for every remaining node (one spare entry at most)
    my_list = [m1] * int((N - m2) * (m2 - m)) + [m2] * int((N - m2) * (m - m1) + 1)
    random.shuffle(my_list)

    for i in range(m2, N):
        targets = np.arange(i)
        degrees = np.sum(adj[:i, :i], axis=1)
        total_degree = np.sum(degrees)
        if total_degree > 0:
            p_normalized = degrees / total_degree
        else:
            # no links yet (m < 1): 0/0 would give NaN probabilities
            p_normalized = np.full(i, 1.0 / i)

        m_tmp = my_list[i - m2]
        idx = np.random.choice(targets, size=m_tmp, replace=False, p=p_normalized)
        adj[i, idx] = 1
        adj[idx, i] = 1

    return adj

def create_ba_sparse(layer):
    """Set ``layer.weight_mask`` from a preferential-attachment graph.

    A graph over ``layer.indim + layer.outdim`` nodes is generated with
    ``generate_barabasi_alberta_graph``, the nodes are split into two groups
    of sizes ``indim`` and ``outdim``, and only links between the two groups
    are kept directly. Links that fall inside a single group are then
    converted into cross-group links. The resulting (indim, outdim) 0/1
    matrix is stored as a float tensor on ``layer.device`` and the total
    number of links is printed.

    Args:
        layer: object exposing ``sparsity``, ``indim``, ``outdim`` and
            ``device``; its ``weight_mask`` attribute is overwritten.
    """
    # Links-per-node parameter: (1 - sparsity) * indim * outdim spread over all nodes.
    m = (1- layer.sparsity) * layer.indim * layer.outdim / (layer.indim + layer.outdim)
    adj = generate_barabasi_alberta_graph(layer.indim + layer.outdim, m)
    nodes = list(range(layer.indim + layer.outdim))
    random.shuffle(nodes)
    
    # Random split of node ids into the two groups.
    # NOTE(review): passing through set() discards the shuffled order inside
    # each group; only the membership comes from the shuffle.
    layer_N = list(set(nodes[:layer.indim]))
    layer_M = list(set(nodes[layer.indim:]))

    # Reorder rows and columns so the first group occupies indices [0, indim).
    adj = adj[layer_N+layer_M]
    adj = adj[:, layer_N+layer_M]
    
    # Keep the strict upper triangle so every undirected link appears once.
    adj = np.triu(adj, k=1)
    # Cross-group block; this is the matrix that becomes the mask.
    final_adj = adj[:layer.indim, layer.indim:]
    # Endpoints of links lying inside a single group (row indices followed by
    # column indices, flattened), expressed relative to that group's block.
    layer1_fru = np.array(adj[:layer.indim, :layer.indim].nonzero()).reshape(-1)
    layer2_fru = np.array(adj[layer.indim:, layer.indim:].nonzero()).reshape(-1)
    np.random.shuffle(layer1_fru)
    np.random.shuffle(layer2_fru)

    if len(layer1_fru) <= len(layer2_fru):
        # print("in")
        # Flags mark which second-group endpoints have already been paired.
        layer2_fru_flag = np.zeros_like(layer2_fru)
        # Greedy pairing: each first-group endpoint takes the first unused
        # second-group endpoint it is not already linked to.
        for i in range(len(layer1_fru)):
            for j in range(len(layer2_fru)):
                if layer2_fru_flag[j] == 0:
                    if final_adj[layer1_fru[i], layer2_fru[j]]:
                        continue
                    else:
                        final_adj[layer1_fru[i], layer2_fru[j]] = 1
                        layer2_fru_flag[j] = 1
                        break
                else:
                    continue
             
        # Second-group endpoints that were never paired.
        zero_indices = np.where(layer2_fru_flag == 0)[0]
        layer2_r = layer2_fru[zero_indices]
        # Only the first half of the leftovers is processed.
        for i in range(len(layer2_r)//2):
            # Pick a first-group node with probability proportional to its
            # current number of cross-group links.
            node_degrees = np.sum(final_adj, axis=1)
            prob_distribution = node_degrees / np.sum(node_degrees)
            # NOTE(review): this retry loop does not terminate if every node
            # with non-zero probability is already linked to layer2_r[i].
            while True:
                # size=1 yields a length-1 array, used below as an index.
                target_node = np.random.choice(np.arange(layer.indim), size=1, replace=False, p=prob_distribution)
                if final_adj[target_node, layer2_r[i]] == 1:
                    continue
                else:
                    final_adj[target_node, layer2_r[i]] = 1
                    break
                
    elif len(layer1_fru) > len(layer2_fru):
        # Mirror of the branch above with the roles of the groups swapped.
        layer1_fru_flag = np.zeros_like(layer1_fru)
        for i in range(len(layer2_fru)):
            for j in range(len(layer1_fru)):
                if layer1_fru_flag[j] == 0:
                    if final_adj[layer1_fru[j], layer2_fru[i]]:
                        continue
                    else:
                        final_adj[layer1_fru[j], layer2_fru[i]] = 1
                        layer1_fru_flag[j] = 1
                        break
                else:
                    continue
        # First-group endpoints that were never paired.
        zero_indices = np.where(layer1_fru_flag == 0)[0]
        layer1_r = layer1_fru[zero_indices]
        # Only the first half of the leftovers is processed.
        for i in range(len(layer1_r)//2):
            # Pick a second-group node with probability proportional to its
            # current number of cross-group links.
            node_degrees = np.sum(final_adj, axis=0)
            prob_distribution = node_degrees / np.sum(node_degrees)
            # NOTE(review): same non-termination risk as in the branch above.
            while True:
                target_node = np.random.choice(np.arange(layer.outdim), size=1, replace=False, p=prob_distribution)
                if final_adj[layer1_r[i], target_node] == 1:
                    continue
                else:
                    final_adj[layer1_r[i], target_node] = 1
                    break
    # Report the final number of links in the mask.
    print(int(np.sum(final_adj)))
    layer.weight_mask = torch.Tensor(final_adj).to(layer.device)
    
def create_self_correlated_sparse(model, corr, dim, soft=False, noise=False):
    """Derive every sparse layer's mask from a correlation matrix.

    ``corr`` is cleaned in place: NaN entries and the diagonal are set to 0,
    and, when ``noise`` is true, standard-normal noise is added. The cleaned
    matrix is then handed to ``update_topology`` for each layer in
    ``model.sparse_layers`` together with that layer's ``n_params``.

    Args:
        model: object exposing ``sparse_layers``.
        corr: 2-D numpy array; modified in place.
        dim: 1 uses ``corr`` as is; 2 uses a 2x2 tiling of ``corr`` (the
            first layer gets only the top half of the tiling). Any other
            value leaves the layers untouched.
        soft: forwarded to ``update_topology``.
        noise: add Gaussian noise to ``corr`` before building the masks.
    """
    corr[np.isnan(corr)] = 0
    diagonal = np.arange(corr.shape[0])
    corr[diagonal, diagonal] = 0

    if noise:
        n_rows, n_cols = corr.shape[0], corr.shape[1]
        corr += np.random.randn(n_rows, n_cols)

    # 1x of the dimension: every layer sees the same matrix.
    if dim == 1:
        for sparse_layer in model.sparse_layers:
            link_budget = sparse_layer.n_params
            update_topology(sparse_layer, corr, link_budget, soft)
        return

    if dim != 2:
        return

    # 2x of the dimension: replicate corr into all four quadrants.
    half = corr.shape[0]
    tiled = np.zeros((half * 2, half * 2))
    top, bottom = slice(None, half), slice(half, None)
    for row_block in (top, bottom):
        for col_block in (top, bottom):
            tiled[row_block, col_block] = corr

    for position, sparse_layer in enumerate(model.sparse_layers):
        link_budget = sparse_layer.n_params
        if position == 0:
            # The first layer only uses the upper half of the tiled matrix.
            update_topology(sparse_layer, tiled[:half, :].copy(), link_budget, soft)
        else:
            update_topology(sparse_layer, tiled, link_budget, soft)
    
def update_topology(layer, corr, number_of_links, soft=False):
    """Write a 0/1 mask built from ``|corr|`` to ``layer.weight_mask``.

    Args:
        layer: object exposing ``device``; its ``weight_mask`` is overwritten.
        corr: 2-D score matrix (anything ``torch.Tensor`` accepts).
        number_of_links: how many entries to switch on. A value <= 0 yields
            an all-zero mask.
        soft: if true, sample ``number_of_links`` entries without replacement
            with probability proportional to ``|corr|``; otherwise keep every
            entry whose magnitude is at least the ``number_of_links``-th
            largest magnitude (ties at the threshold are all kept, so the
            mask may contain more than ``number_of_links`` ones).
    """
    scores = torch.abs(torch.Tensor(corr))
    mask = torch.zeros_like(scores)

    # Guard: without it, number_of_links == 0 would index [-1] below and turn
    # the threshold into the smallest score, i.e. a fully dense mask.
    if number_of_links > 0:
        flat_scores = scores.flatten()
        if soft:
            probabilities = flat_scores / flat_scores.sum()
            picked = torch.multinomial(probabilities, number_of_links, replacement=False)
            mask.view(-1)[picked] = 1
        else:
            # Ascending sort of the negated scores == descending magnitudes.
            threshold = -torch.sort(-flat_scores)[0][number_of_links - 1]
            mask[scores >= threshold] = 1

    layer.weight_mask = mask.to(layer.device)


def create_self_correlated_scheduler(model, corr, dim):
    """Rebuild every sparse layer's mask from a correlation matrix.

    ``corr`` is modified in place (NaN entries and the diagonal become 0) and
    then passed, with each layer's ``n_params``, to ``update_topology`` using
    its default hard-threshold mode.

    Args:
        model: object exposing ``sparse_layers``.
        corr: 2-D numpy array; modified in place.
        dim: 1 feeds ``corr`` to every layer; 2 feeds a 2x2 tiling of
            ``corr`` (top half only for the first layer). Other values do
            nothing beyond cleaning ``corr``.
    """
    nan_positions = np.isnan(corr)
    corr[nan_positions] = 0
    corr[np.diag_indices(corr.shape[0])] = 0

    if dim == 1:
        # 1x of the dimension
        for sparse_layer in model.sparse_layers:
            budget = sparse_layer.n_params
            update_topology(sparse_layer, corr, budget)
    elif dim == 2:
        # 2x of the dimension: corr repeated in each of the four quadrants.
        half = corr.shape[0]
        full = 2 * half
        expanded = np.zeros((full, full))
        for row_start in (0, half):
            for col_start in (0, half):
                expanded[row_start:row_start + half, col_start:col_start + half] = corr

        for position, sparse_layer in enumerate(model.sparse_layers):
            budget = sparse_layer.n_params
            # The first layer is restricted to the upper half of the tiling.
            source = expanded[:half, :].copy() if position == 0 else expanded
            update_topology(sparse_layer, source, budget)

def update_topology_scheduler(w, corr, number_of_links):
    """Return a 0/1 mask keeping the ``number_of_links`` largest ``|corr|`` entries.

    Args:
        w: tensor whose ``device`` the returned mask is moved to.
        corr: 2-D score matrix (anything ``torch.Tensor`` accepts).
        number_of_links: number of entries to keep. A value <= 0 yields an
            all-zero mask. Entries tied with the threshold are all kept, so
            the mask may contain more than ``number_of_links`` ones.

    Returns:
        Float tensor with the shape of ``corr`` on ``w.device``.
    """
    scores = torch.abs(torch.Tensor(corr))
    adj = torch.zeros_like(scores)
    corr_flatten = scores.flatten()
    print('size:',corr_flatten.size())
    print(f'num of links:{ number_of_links}')
    if number_of_links <= 0:
        # Nothing to keep; avoids the negative index wrapping to the minimum.
        return adj.to(w.device)
    # The k-th largest magnitude sits at index k - 1 of the descending order.
    # Indexing at k (as before) kept one link too many and raised IndexError
    # when k equalled the number of entries.
    threshold = -torch.sort(-corr_flatten)[0][number_of_links - 1]
    print('sorted')
    adj[scores >= threshold] = 1
    return adj.to(w.device)

def update_topology_scheduler_soft(w, corr, number_of_links):
    """Sample a 0/1 mask with probability proportional to ``|corr|``.

    ``number_of_links`` distinct entries are drawn without replacement via
    ``torch.multinomial`` and set to 1; all other entries stay 0.

    Args:
        w: tensor whose ``device`` the returned mask is moved to.
        corr: score matrix (anything ``torch.Tensor`` accepts).
        number_of_links: number of entries to sample.

    Returns:
        Float tensor with the shape of ``corr`` on ``w.device``.
    """
    corr = torch.Tensor(corr)
    flat_mask = torch.zeros_like(corr).flatten()

    # Normalise magnitudes into a sampling distribution over flat positions.
    magnitudes = torch.abs(corr).flatten()
    weights = magnitudes / magnitudes.sum()
    picked = torch.multinomial(weights, number_of_links, replacement=False)

    flat_mask[picked] = 1
    return flat_mask.view(corr.shape).to(w.device)