import numpy as np
import torch
import torch.distributed as dist
import os
import sys
sys.path.append("../../")
sys.path.append("../../mlp")
from dst_util import get_W
from sparse_topology_initialization import create_ws_sparse_scheduler, create_cws_sparse_scheduler, create_dendritic_sparse_scheduler, create_BHI_sparse_scheduler, create_QHI_sparse_scheduler, create_ws_cross_scheduler, create_ws1_sparse_scheduler, create_ws2_sparse_scheduler, create_ws3_sparse_scheduler
import math
from scipy.sparse import csr_matrix
import CH_scores
from scipy.io import loadmat, savemat
from sparse_topology_initialization import update_topology_scheduler, update_topology_scheduler_soft

from torchvision import datasets, transforms
from load_data import TinyImageNet_load,TransformedSubset

def custom_multinomial(weights, num_samples, replacement=False, device='cpu'):
    """
    Custom multinomial sampling to bypass the 2^24 element limitation.
    Uses numpy.random.choice to handle large weights tensors.

    Args:
        weights: 1-D tensor of non-negative sampling weights (they are
            normalised here, so they do not have to sum to 1).
        num_samples: number of indices to draw.
        replacement: whether an index may be drawn more than once.
        device: device on which the returned index tensor is placed.

    Returns:
        1-D integer tensor with `num_samples` indices into `weights`.
    """
    # numpy works on host memory. Normalising in float64 keeps the
    # probabilities well inside numpy's "sums to 1" tolerance even when the
    # input is a very large float32 tensor.
    weights_np = weights.detach().cpu().numpy().astype(np.float64, copy=False)
    probabilities = weights_np / np.sum(weights_np)
    # An int population is sampled exactly like arange(n), but does not force
    # numpy to turn a Python range object with millions of items into an array.
    sampled_indices = np.random.choice(
        len(weights_np), size=num_samples, p=probabilities, replace=replacement
    )
    # Return the tensor on the specified device
    return torch.from_numpy(sampled_indices).to(device)

def remove_unactive_links_backward(current_adj, after_adj):
    """Zero the rows of `current_adj` whose matching column of `after_adj` sums to zero.

    The column sums of `after_adj` are turned into a 0/1 indicator (positive
    sums become 1) that is broadcast over the rows of `current_adj`. The number
    of links removed this way is printed.

    Returns:
        The filtered adjacency matrix (a new tensor).
    """
    column_totals = torch.sum(after_adj, dim=0)
    # Positive totals collapse to 1; every other value is kept as it is.
    row_indicator = torch.where(column_totals > 0, torch.ones_like(column_totals), column_totals)
    links_before = torch.sum(current_adj)
    filtered_adj = current_adj * row_indicator.reshape(-1, 1)
    removed = int(links_before - torch.sum(filtered_adj))
    print("Number of removed unactive links backwards: ", removed)

    return filtered_adj

def remove_unactive_links_forward(current_adj, before_adj):
    """Zero the columns of `current_adj` whose matching row of `before_adj` sums to zero.

    The row sums of `before_adj` are turned into a 0/1 indicator (positive sums
    become 1) that is broadcast over the columns of `current_adj`. The number
    of links removed this way is printed.

    Returns:
        The filtered adjacency matrix (a new tensor).
    """
    row_totals = torch.sum(before_adj, dim=1)
    # Positive totals collapse to 1; every other value is kept as it is.
    column_indicator = torch.where(row_totals > 0, torch.ones_like(row_totals), row_totals)
    links_before = torch.sum(current_adj)
    filtered_adj = current_adj * column_indicator.reshape(1, -1)
    removed = int(links_before - torch.sum(filtered_adj))
    print("Number of removed unactive links forwards: ", removed)
    return filtered_adj




class IndexMaskHook:
    """Gradient hook that masks a layer's gradient with the scheduler's mask.

    While the scheduler reports that gradients should be accumulated, the
    unmasked gradient is also folded into `dense_grad` as a running average
    over `scheduler.grad_accumulation_n` calls; otherwise `dense_grad` is
    cleared.
    """

    def __init__(self, layer, scheduler):
        # `layer` is the index used to look up this layer's mask in the scheduler.
        self.layer, self.scheduler = layer, scheduler
        self.dense_grad = None

    def __name__(self):
        return 'IndexMaskHook'

    @torch.no_grad()
    def __call__(self, grad):
        owner = self.scheduler
        layer_mask = owner.backward_masks[self.layer]

        # Dense gradients are only tracked when the scheduler asks for them.
        if not owner.check_if_backward_hook_should_accumulate_grad():
            self.dense_grad = None
            return grad * layer_mask

        if self.dense_grad is None:
            # Start from zeros so each call contributes grad / n to the average.
            self.dense_grad = torch.zeros_like(grad)
        self.dense_grad += grad / owner.grad_accumulation_n
        return grad * layer_mask


def _create_step_wrapper(scheduler, optimizer):
    if scheduler.args.ssam:
        _unwrapped_step = optimizer.second_step
    else:
        _unwrapped_step = optimizer.step
    def _wrapped_step():
        if scheduler.args.ssam:
            _unwrapped_step(zero_grad=True)
        else:
            _unwrapped_step()
        scheduler.reset_momentum()
        scheduler.apply_mask_to_weights()
        # scheduler.apply_mask_to_gradients()
    optimizer.step = _wrapped_step



class DSTScheduler:

    def __init__(self, model, optimizer, T_end=None, sparsity_distribution='uniform', ignore_linear_layers=False, delta=100, alpha=0.3, static_topo=False, grad_accumulation_n=1, state_dict=None, args=None):
        """Set up sparse training for `model` and hook it into `optimizer`.

        Args:
            model: network handed to `get_W` to collect the tensors to sparsify.
            optimizer: optimizer whose `step` is patched so masks are
                re-applied after every update.
            T_end: last step at which topology updates run. It is divided by
                `delta_T` below, so it must be a number unless `state_dict`
                supplies it.
            sparsity_distribution: 'uniform' or 'non-uniform' (asserted below).
            ignore_linear_layers: stored on the instance; not used in this method.
            delta: number of steps between topology updates (stored as `delta_T`).
            alpha: stored as `self.alpha` (used by `cosine_annealing`).
            static_topo: stored on the instance; when true, `__call__` never
                triggers a topology update.
            grad_accumulation_n: number of calls over which the hooks average
                dense gradients; must satisfy 0 < n < delta.
            state_dict: optional dict produced by `state_dict()`; when given,
                the attributes are restored from it instead of being initialised.
            args: experiment configuration namespace. Despite the `None`
                default it is required, because its attributes are read
                immediately.

        Raises:
            Exception: if the dense allocation is outside (0, 1], or if a
                tensor already carries a hook from another scheduler.
        """
        self.args = args
        # Fraction of weights kept at initialisation: GraNet/GMP start from
        # `granet_init_sparsity`, every other mode starts at the target sparsity.
        self.dense_allocation = 1 - self.args.granet_init_sparsity if self.args.granet or self.args.gmp else 1 - self.args.sparsity
        if self.dense_allocation <= 0 or self.dense_allocation > 1:
            raise Exception('Dense allocation must be on the interval (0, 1]. Got: %f' % self.dense_allocation)

        self.model = model
        self.optimizer = optimizer

        # NOTE(review): get_W is defined outside this file; it presumably returns
        # the sparsifiable weight tensors plus per-tensor metadata - confirm.
        self.W, self._linear_layers_mask, self.chain_list = get_W(model, return_linear_layers_mask=True)
        
        # if distributed these values will be populated
        self.is_dist = dist.is_initialized()
        self.world_size = dist.get_world_size() if self.is_dist else None

        

        # modify optimizer.step() function to call "reset_momentum" after
        _create_step_wrapper(self, optimizer)
            
        # Number of elements of every tracked tensor.
        self.N = [torch.numel(w) for w in self.W]
        if self.args.early_stop:
            # One flag per layer; __call__ stops updating once all flags are set.
            self.early_stop_signal = torch.zeros(len(self.W))
        if state_dict is not None:
            # Resume: restore attributes (masks, S, counters, hyperparameters)
            # and re-apply the restored masks to the weights.
            self.load_state_dict(state_dict)
            self.apply_mask_to_weights()

        else:
            self.sparsity_distribution = sparsity_distribution
            self.static_topo = static_topo
            self.grad_accumulation_n = grad_accumulation_n
            self.ignore_linear_layers = ignore_linear_layers
            self.backward_masks = None

            # define sparsity allocation
            self.S = []
            for i, (W, is_linear) in enumerate(zip(self.W, self._linear_layers_mask)):
                if self.args.EM_S:
                    # EM_S uses a sparsity 0.05 below the nominal one.
                    self.S.append((1 - self.dense_allocation - 0.05))
                elif self.args.granet or self.args.gmp:
                    self.S.append(1 - self.dense_allocation)
                else:
                    # Same value as the branch above; every layer gets the
                    # same sparsity here.
                    self.S.append(1 - self.dense_allocation)
            if args.init_mode == "swi" or args.init_mode == "kaiming":
                # reset the parameters with swi
                # (reset_parameters is defined outside the visible part of the class)
                self.reset_parameters()
                pass
            
            if self.args.history_weights:
                # Despite the name these are CPU copies of the weight values,
                # taken before the weights are sparsified.
                self.history_masks = [self.W[i].detach().clone().cpu() for i in range(len(self.W))]
            # randomly sparsify model according to S
            print('here')
            self.random_sparsify()
            print('finish')

            # scheduler keeps a log of how many times it's called. this is how it does its scheduling
            self.step = 0
            self.dst_steps = 0

            # define the actual schedule
            self.delta_T = delta
            self.alpha = alpha
            self.T_end = T_end

        # also, register backward hook so sparse elements cannot be recovered during normal training
        self.backward_hook_objects = []
        for i, w in enumerate(self.W):
            # if sparsity is 0%, skip
            if self.S[i] <= 0:
                # Keep the list aligned with self.W so it can be indexed by layer.
                self.backward_hook_objects.append(None)
                continue
            if getattr(w, '_has_rigl_backward_hook', False):
                print(i, w.shape)
                # print()
                raise Exception('This model already has been registered to a DSTScheduler.')
        
            self.backward_hook_objects.append(IndexMaskHook(i, self))
            w.register_hook(self.backward_hook_objects[-1])
            # Marker checked above to refuse a second registration on the same tensor.
            setattr(w, '_has_rigl_backward_hook', True)

        # Pruning schedule bookkeeping, in units of topology updates:
        # final_iter is the total number of updates, ini_iter the number of
        # updates covered by the first `granet_init_epoch` epochs.
        self.final_iter = int(self.T_end / self.delta_T)
        self.ini_iter = int(int((self.T_end/self.args.epochs) * self.args.granet_init_epoch)/ self.delta_T)
        self.total_prune_iter = self.final_iter - self.ini_iter
        # NOTE(review): this compares against the `delta` argument, not
        # self.delta_T, so a delta_T restored from state_dict is not what is checked.
        assert self.grad_accumulation_n > 0 and self.grad_accumulation_n < delta
        assert self.sparsity_distribution in ('uniform', 'non-uniform')


    def state_dict(self):
        """Return the scheduler state as a plain dict.

        The result holds the sparsity bookkeeping (`dense_allocation`, `S`,
        `N`), the schedule hyperparameters grouped under `'hyperparams'`, the
        step counters, the current masks and `_linear_layers_mask`. Values are
        referenced, not copied.
        """
        hyperparam_names = (
            'delta_T',
            'alpha',
            'T_end',
            'ignore_linear_layers',
            'static_topo',
            'sparsity_distribution',
            'grad_accumulation_n',
        )
        snapshot = dict(
            dense_allocation=self.dense_allocation,
            S=self.S,
            N=self.N,
            hyperparams={name: getattr(self, name) for name in hyperparam_names},
            step=self.step,
            dst_steps=self.dst_steps,
            backward_masks=self.backward_masks,
        )
        snapshot['_linear_layers_mask'] = self._linear_layers_mask
        return snapshot

    def load_state_dict(self, state_dict):
        """Copy every entry of `state_dict` onto the instance as an attribute.

        Values that are plain dicts (such as `'hyperparams'`) are first
        flattened recursively, so their entries become attributes as well; the
        dict itself is then also stored under its own key.
        """
        for attr_name, attr_value in state_dict.items():
            if type(attr_value) is dict:
                # Nested group: expose its entries directly on the instance.
                self.load_state_dict(attr_value)
            setattr(self, attr_name, attr_value)


    @torch.no_grad()
    def random_sparsify(self):
        is_dist = dist.is_initialized()
        self.backward_masks = []
        self.record_mask = []
        if self.args.QHI:
            masks = create_QHI_sparse_scheduler(self.args.sparsity, self.W, self.args)
        for l, w in enumerate(self.W):
            print(f'-----layer:{l}-----')
            is_last = (l == len(self.W)-1)            
            # if sparsity is 0%, skip
            if self.S[l] <= 0:
                self.backward_masks.append(None)
                continue
            mask=None
            if self.args.WS:
                mask = create_ws_sparse_scheduler(self.S[l], w, self.args)
                save_dir = f"../Cannistraci-Hebb-training/mlp_and_cnn/dendritic_masks/WS/{self.args.dataset}_{self.args.regrow_method}_{self.args.remove_method}_rw_{self.args.random_rewiring}"
                save_path = os.path.join(save_dir, f"mask_{l}.npy")

                os.makedirs(save_dir, exist_ok=True)
                mask_np = mask.cpu().numpy()
                np.save(save_path, mask_np)
                print(f"mask save to {save_path}")  
            elif self.args.cross:
                mask = create_ws_cross_scheduler(self.S[l], w, self.args)
                save_dir = f"../Cannistraci-Hebb-training/mlp_and_cnn/dendritic_masks/cross/{self.args.dataset}_{self.args.regrow_method}_{self.args.remove_method}_rw_{self.args.random_rewiring}"
                save_path = os.path.join(save_dir, f"mask_{l}.npy")

                os.makedirs(save_dir, exist_ok=True)
                mask_np = mask.cpu().numpy()
                np.save(save_path, mask_np)
                print(f"mask save to {save_path}")  
            elif self.args.CWS:
                print("selected CWS")
                mask = create_cws_sparse_scheduler(self.S[l], w, self.args)
                save_dir = f"../Cannistraci-Hebb-training/mlp_and_cnn/dendritic_masks/CWS/{self.args.dataset}_{self.args.regrow_method}_{self.args.remove_method}_sigmax_{self.args.sigma_x}_sigmay_{self.args.sigma_y}_rho_{self.args.rho}_rw_{self.args.random_rewiring}"
                save_path = os.path.join(save_dir, f"mask_{l}.npy")

                os.makedirs(save_dir, exist_ok=True)
                mask_np = mask.cpu().numpy()
                np.save(save_path, mask_np)
                print(f"mask save to {save_path}") 
            elif self.args.WS1:
                print("selected WS1")
                mask = create_ws1_sparse_scheduler(self.S[l], w, self.args)
                save_dir = f"../Cannistraci-Hebb-training/mlp_and_cnn/dendritic_masks/WS1/{self.args.dataset}_{self.args.regrow_method}_{self.args.remove_method}_delta_dist_{self.args.delta_dist}_delta_{self.args.delta}_rw_{self.args.random_rewiring}"
                save_path = os.path.join(save_dir, f"mask_{l}.npy")

                os.makedirs(save_dir, exist_ok=True)
                mask_np = mask.cpu().numpy()
                np.save(save_path, mask_np)
                print(f"mask save to {save_path}")      
            elif self.args.WS2:
                print("selected WS2")
                mask = create_ws2_sparse_scheduler(self.S[l], w, self.args)
                save_dir = f"../Cannistraci-Hebb-training/mlp_and_cnn/dendritic_masks/WS2/{self.args.dataset}_{self.args.regrow_method}_{self.args.remove_method}_delta_dist_{self.args.delta_dist}_delta_{self.args.delta}_rw_{self.args.random_rewiring}"
                save_path = os.path.join(save_dir, f"mask_{l}.npy")

                os.makedirs(save_dir, exist_ok=True)
                mask_np = mask.cpu().numpy()
                np.save(save_path, mask_np)
                print(f"mask save to {save_path}")      
            elif self.args.WS3:
                print("selected WS3")
                mask = create_ws3_sparse_scheduler(self.S[l], w, self.args)
                save_dir = f"../Cannistraci-Hebb-training/mlp_and_cnn/dendritic_masks/WS3/{self.args.dataset}_{self.args.regrow_method}_{self.args.remove_method}_degree_dist_{self.args.degree_dist}_delta_dist_{self.args.delta_dist}_delta_{self.args.delta}"
                save_path = os.path.join(save_dir, f"mask_{l}.npy")

                os.makedirs(save_dir, exist_ok=True)
                mask_np = mask.cpu().numpy()
                np.save(save_path, mask_np)
                print(f"mask save to {save_path}")      
            elif self.args.BHI:
                mask = create_BHI_sparse_scheduler(self.S[l], w, self.args, is_last_layer=is_last)
                save_dir = f"../Cannistraci-Hebb-training/mlp_and_cnn/dendritic_masks/BHI/{self.args.dataset}_{self.args.regrow_method}_{self.args.remove_method}_T_{self.args.BHI_T}_gamma_{self.args.BHI_gamma}_dist_{self.args.BHI_distr}_rm_{self.args.rewire_mode}"
                save_path = os.path.join(save_dir, f"mask_{l}.npy")

                os.makedirs(save_dir, exist_ok=True)
                mask_np = mask.cpu().numpy()
                np.save(save_path, mask_np)
                print(f"mask save to {save_path}")      
            elif self.args.QHI:
                mask = masks[l]
                save_dir = f"../Cannistraci-Hebb-training/mlp_and_cnn/dendritic_masks/QHI/{self.args.dataset}_{self.args.regrow_method}_{self.args.remove_method}_T_{self.args.BHI_T}_gamma_{self.args.BHI_gamma}_dist_{self.args.BHI_distr}_rm_{self.args.rewire_mode}"
                save_path = os.path.join(save_dir, f"mask_{l}.npy")

                os.makedirs(save_dir, exist_ok=True)
                mask_np = mask.cpu().numpy()
                np.save(save_path, mask_np)
                print(f"mask save to {save_path}")  
            elif self.args.DNM:
                mask = create_dendritic_sparse_scheduler(self.S[l], w, self.args)
                mask_np = mask.cpu().numpy()
                ff = f"../Cannistraci-Hebb-training/mlp_and_cnn/dendritic_masks/degreedist_{self.args.degree_dist}_M_{self.args.M}_Mdist_{self.args.M_dist}_gamma_{self.args.gamma}_gammadist_{self.args.gamma_dist}_synapticdist_{self.args.synaptic_dist}.npy"
                np.save(ff, mask_np)
                print(f"Adjacency matrix saved at {ff}")
            elif self.args.self_correlated_sparse:
                n = self.N[l]
                number_of_links = int((1-self.S[l]) * n)

                corr_filename = f"self-correlated_sparse/{self.args.dataset}"
                if os.path.exists(corr_filename + "/corr.mat"):
                    corr = loadmat(corr_filename + "/corr.mat")["corr"]
                else:
                    dataloader, input_of_sparse_layer = load_calib_dataset(self.args, data_root="../data")

                    print("Using self correlated sparse of mlp!!!")
                    
                    for batch_idx, (data, _) in enumerate(dataloader):
                        input_of_sparse_layer[:,batch_idx*self.args.batch_size:batch_idx*self.args.batch_size + data.shape[0]] = data.reshape(data.shape[0], -1).numpy().transpose(1, 0)
                    corr = np.corrcoef(input_of_sparse_layer)
                    os.makedirs(corr_filename)
                    print("done")
                    
                    savemat(corr_filename + "/corr.mat", {"corr":corr})   ### ADAPT TO QWEN

                isnan = np.isnan(corr)
                corr[isnan] = 0

                for i in range(corr.shape[0]):
                    corr[i, i] = 0

                if self.args.dim == 1:
                    mask = update_topology_scheduler(w, corr, number_of_links)
                elif self.args.dim == 2:
                    dimension = corr.shape[0] * 2
                    expanded_dimension = np.zeros((dimension, dimension))
                    expanded_dimension[:dimension//2, :dimension//2] = corr
                    expanded_dimension[:dimension//2, dimension//2:] = corr
                    expanded_dimension[dimension//2:, :dimension//2] = corr
                    expanded_dimension[dimension//2:, dimension//2:] = corr
                    mask = update_topology_scheduler(w, expanded_dimension[:w.shape[0], :w.shape[1]], number_of_links)
                else:
                    raise NotImplementedError
            elif self.args.soft_self_correlated_sparse:
                n = self.N[l]
                number_of_links = int((1-self.S[l]) * n)

                corr_filename = f"self-correlated_sparse/{self.args.dataset}"
                if os.path.exists(corr_filename + "/corr.mat"):
                    corr = loadmat(corr_filename + "/corr.mat")["corr"]
                else:
                    dataloader, input_of_sparse_layer = load_calib_dataset(self.args, data_root="../data")

                    print("Using self correlated sparse of mlp!!!")
                    
                    for batch_idx, (data, _) in enumerate(dataloader):
                        input_of_sparse_layer[:,batch_idx*self.args.batch_size:batch_idx*self.args.batch_size + data.shape[0]] = data.reshape(data.shape[0], -1).numpy().transpose(1, 0)
                    corr = np.corrcoef(input_of_sparse_layer)
                    os.makedirs(corr_filename)
                    print("done")
                    
                    savemat(corr_filename + "/corr.mat", {"corr":corr})   ### ADAPT TO QWEN

                isnan = np.isnan(corr)
                corr[isnan] = 0

                for i in range(corr.shape[0]):
                    corr[i, i] = 0

                if self.args.dim == 1:
                    mask = update_topology_scheduler_soft(w, corr, number_of_links)
                elif self.args.dim == 2:
                    dimension = corr.shape[0] * 2
                    expanded_dimension = np.zeros((dimension, dimension))
                    expanded_dimension[:dimension//2, :dimension//2] = corr
                    expanded_dimension[:dimension//2, dimension//2:] = corr
                    expanded_dimension[dimension//2:, :dimension//2] = corr
                    expanded_dimension[dimension//2:, dimension//2:] = corr

                    mask = update_topology_scheduler_soft(w, expanded_dimension[:w.shape[0], :w.shape[1]], number_of_links)
                else:
                    raise NotImplementedError
            else:
                n = self.N[l]
                s = int(self.S[l] * n)
                perm = torch.randperm(n)
                perm = perm[:s]
                flat_mask = torch.ones(n, device=w.device)
                flat_mask[perm] = 0
                mask = torch.reshape(flat_mask, w.shape)

            if self.is_dist:
                dist.broadcast(mask, 0)
            mask = mask.bool()
            w *= mask
            self.backward_masks.append(mask)
            if self.args.itop:
                self.record_mask.append(mask)
            


    def __str__(self):
        s = 'DSTScheduler(\n'
        s += 'layers=%i,\n' % len(self.N)

        # calculate the number of non-zero elements out of the total number of elements
        N_str = '['
        S_str = '['
        sparsity_percentages = []
        total_params = 0
        total_nonzero = 0

        for N, S, mask, W, is_linear in zip(self.N, self.S, self.backward_masks, self.W, self._linear_layers_mask):
            actual_S = torch.sum(W[mask == 0] == 0).item()
            N_str += ('%i/%i, ' % (N-actual_S, N))
            sp_p = float(N-actual_S) / float(N) * 100
            S_str += '%.2f%%, ' % sp_p
            sparsity_percentages.append(sp_p)
            total_params += N
            total_nonzero += N-actual_S

        N_str = N_str[:-2] + ']'
        S_str = S_str[:-2] + ']'
        
        s += 'nonzero_params=' + N_str + ',\n'
        s += 'nonzero_percentages=' + S_str + ',\n'
        s += 'total_nonzero_params=' + ('%i/%i (%.2f%%)' % (total_nonzero, total_params, float(total_nonzero)/float(total_params)*100)) + ',\n'
        # s += 'total_CONV_nonzero_params=' + ('%i/%i (%.2f%%)' % (total_conv_nonzero, total_conv_params, float(total_conv_nonzero)/float(total_conv_params)*100)) + ',\n'
        s += 'step=' + str(self.step) + ',\n'
        s += 'num_dst_steps=' + str(self.dst_steps) + ',\n'
        s += 'ignoring_linear_layers=' + str(self.ignore_linear_layers) + ',\n'
        s += 'sparsity_distribution=' + str(self.sparsity_distribution) + ',\n'
        if self.args.WS:
            s += 'WS=True, WS_beta=' + str(self.args.ws_beta) + ',\n'
        
        if self.args.gmp:
            s += f'GMP=True, init_sparsity={self.args.granet_init_sparsity}\n'
            s += f'pruning_scheduler={self.args.pruning_scheduler}, pruning_method={self.args.pruning_method}\n'
        elif self.args.regrow_method == "fc":
            s += 'FC=True,\n'
        else:
            if self.args.granet:
                s += f'granet=True, init_sparsity={self.args.granet_init_sparsity}\n'
                s += f'pruning_scheduler={self.args.pruning_scheduler}, pruning_method={self.args.pruning_method}\n'

            if self.args.EM_S:
                s += 'EM_S=True,\n'
            elif self.args.adaptive_zeta:
                s += 'Adaptive_zeta=True,\n'
            s += 'regrow_method=' + self.args.regrow_method + ',\n'
            s += 'remove_method=' + self.args.remove_method + ',\n'
        
        if self.args.history_weights:
            s += 'history_weights=True,\n'
            
        s += 'target_sparsity=' + str(self.args.sparsity) + ',\n'
        

        return s + ')'


    @torch.no_grad()
    def reset_momentum(self):
        for w, mask, s in zip(self.W, self.backward_masks, self.S):
            # if sparsity is 0%, skip
            if s <= 0:
                continue
            # optimizer is SGD
            param_state = self.optimizer.state[w]
            if self.args.ssam:
                base_param_state = self.optimizer.base_optimizer.state[w]
            # print(param_state.keys())
            # exit()
            optimizer_state_list = ["momentum_buffer", "prev_grad", "prev_u", "e_w"]
            for optimizer_state in optimizer_state_list:
                if optimizer_state in param_state:
                    # mask the momentum matrix
                    buf = param_state[optimizer_state]
                    buf *= mask
                if self.args.ssam:
                    if optimizer_state in base_param_state:
                        buf = base_param_state[optimizer_state]
                        buf *= mask


    @torch.no_grad()
    def apply_mask_to_weights(self):
        for w, mask, s in zip(self.W, self.backward_masks, self.S):
            # if sparsity is 0%, skip
            if s <= 0:
                continue
                
            w *= mask


    @torch.no_grad()
    def apply_mask_to_gradients(self):
        for w, mask, s in zip(self.W, self.backward_masks, self.S):
            # if sparsity is 0%, skip
            if s <= 0:
                continue

            w.grad *= mask
    
    @torch.no_grad()
    def apply_mask_to_history_weights(self):
        for w,mask,s,history_mask in zip(self.W , self.backward_masks , self.S , self.history_masks):
            if s<=0:
                continue
            new_links=(w==0.) & (mask==1)
            assert 'cuda' in str(new_links.device)
            history_mask=history_mask.to(self.args.device)
            w*=mask
            w[new_links]=history_mask[new_links]


    
    def check_if_backward_hook_should_accumulate_grad(self):
        """
        Tell the backward hooks whether dense gradients must be accumulated.

        Returns True when the next topology update is at most
        `self.grad_accumulation_n` steps away, and False otherwise or once
        `self.step` has reached `self.T_end`.
        """
        schedule_finished = self.step >= self.T_end
        if schedule_finished:
            return False

        steps_into_period = self.step % self.delta_T
        return (self.delta_T - steps_into_period) <= self.grad_accumulation_n


    def cosine_annealing(self):
        """Return `alpha` decayed by a half-cosine: `alpha` at step 0, 0 at step `T_end`."""
        half_alpha = self.alpha / 2
        phase = (self.step * np.pi) / self.T_end
        return half_alpha * (1 + np.cos(phase))


    def __call__(self):
        self.step += 1
        if self.static_topo:
            return True
        
        if self.args.early_stop:
            if torch.sum(self.early_stop_signal) == len(self.W):
                # print("All layer early stopped!")
                return True
        
        if (self.step % self.delta_T) == 0 and self.step <= self.T_end: # check schedule
            self._dst_step()
            self.dst_steps += 1
            print(self)
            return False
        if (self.step % self.delta_T) == 0:
            print(self)
        return True

    def uniform_pruning(self):
        curr_prune_iter = int((self.step - self.ini_iter) / self.delta_T)
        
        if self.args.pruning_scheduler == "linear":
            curr_prune_rate = (self.args.sparsity - self.args.granet_init_sparsity)* curr_prune_iter / self.total_prune_iter + self.args.granet_init_sparsity
        elif self.args.pruning_scheduler == "granet":
            prune_decay = (1 - ((curr_prune_iter - self.ini_iter) / self.total_prune_iter)) ** 3
            curr_prune_rate = self.args.granet_init_sparsity + (self.args.sparsity - self.args.granet_init_sparsity) * (1 - prune_decay)
            
        elif self.args.pruning_scheduler == "s_shape":
            mid_prune_step = self.total_prune_iter / 2
            # S-shape pruning curve
            
            k = 6/mid_prune_step
            
            prune_rate_step_0 = 1/(1 + np.exp(-k * (- mid_prune_step)))
            prune_rate_step_final = 1/(1 + np.exp(-k * (self.total_prune_iter - mid_prune_step)))
            scale_factor = 1 / (prune_rate_step_final - prune_rate_step_0)
            
            

            
            curr_prune_rate = (self.args.sparsity - self.args.granet_init_sparsity) * ((1 / (1 + np.exp(-k * (curr_prune_iter - mid_prune_step)))- 0.5) * scale_factor + 0.5)  + self.args.granet_init_sparsity
        
        else:
            raise NotImplementedError
        
        
        print('******************************************************')
        print(f'Pruning Progress is {curr_prune_iter - self.ini_iter} / {self.total_prune_iter}')
        print('******************************************************')
        
        for l, w in enumerate(self.W):
            if self.S[l] <= 0:
                continue
            if self.args.history_weights:
                weight = self.history_masks[l].clone().to(w.device)
            else:
                if self.is_dist:
                    dist.all_reduce(w)
                    w /= self.world_size
                weight = w.clone()
            
            if self.args.pruning_method == "weight_magnitude":
                weight_abs = torch.abs(weight)
                
            elif self.args.pruning_method == "ri":
                eps = 0.00001
                weight_abs = torch.abs(weight)/torch.sum(torch.abs(weight)+ eps, dim=0) + torch.abs(weight)/torch.sum(torch.abs(weight)+ eps, dim=1).reshape(-1, 1)
                
            elif self.args.pruning_method == "MEST":
                score_grow = torch.abs(self.backward_hook_objects[l].dense_grad)
                if self.is_dist:
                    dist.all_reduce(score_grow)
                    score_grow /= self.world_size
                weight_abs = torch.abs(weight) + self.args.factor * torch.abs(score_grow)
                
            else:
                raise NotImplementedError
            
            
            weight_abs_flatten = torch.flatten(weight_abs)
            num_params_to_keep = int(len(weight_abs_flatten) * (1 - curr_prune_rate))
            threshold, _ = torch.topk(weight_abs_flatten, num_params_to_keep, sorted=True)
            acceptable_score = threshold[-1]
            self.backward_masks[l] = (weight_abs > acceptable_score).float().to(w.device)
            self.S[l] = 1 - torch.sum(self.backward_masks[l]).item() / self.N[l]
            if self.is_dist:
                dist.broadcast(self.backward_masks[l], 0)
            
    
    def non_uniform_pruning(self):
        curr_prune_iter = int((self.step - self.ini_iter) / self.delta_T)
        
        if self.args.pruning_scheduler == "linear":
            curr_prune_rate = (self.args.sparsity - self.args.granet_init_sparsity) * curr_prune_iter / self.total_prune_iter + self.args.granet_init_sparsity
        elif self.args.pruning_scheduler == "granet":
            prune_decay = (1 - ((curr_prune_iter - self.ini_iter) / self.total_prune_iter)) ** 3
            curr_prune_rate = self.args.granet_init_sparsity + (self.args.sparsity - self.args.granet_init_sparsity) * (1 - prune_decay)
            
        elif self.args.pruning_scheduler == "s_shape":
            mid_prune_step = self.total_prune_iter / 2
            # S-shape pruning curve
            k = 6/mid_prune_step
            
            prune_rate_step_0 = 1/(1 + np.exp(-k * (- mid_prune_step)))
            prune_rate_step_final = 1/(1 + np.exp(-k * (self.total_prune_iter - mid_prune_step)))
            scale_factor = 1 / (prune_rate_step_final - prune_rate_step_0)
            
            

            
            curr_prune_rate = (self.args.sparsity - self.args.granet_init_sparsity) * ((1 / (1 + np.exp(-k * (curr_prune_iter - mid_prune_step)))- 0.5) * scale_factor + 0.5)  + self.args.granet_init_sparsity
        
        else:
            raise NotImplementedError
        
        print('******************************************************')
        print(f'Pruning Progress is {curr_prune_iter - self.ini_iter} / {self.total_prune_iter}')
        print('******************************************************')

        weight_abs = []
        for l, w in enumerate(self.W):
            # print(f"Layer {l}: type of self.backward_masks[l] is {self.backward_masks[l].dtype}")
            if self.S[l] <= 0:
                continue
            if self.args.history_weights:
                weight = self.history_masks[l].clone().to(w.device)
            else:
                if self.is_dist:
                    dist.all_reduce(w)
                    w /= self.world_size
                weight = w.clone()
            
            if self.args.pruning_method == "weight_magnitude":
                weight_abs.append(torch.abs(weight))
                
            elif self.args.pruning_method == "ri":
                eps = 0.00001
                weight_abs.append(torch.abs(weight)/torch.sum(torch.abs(weight)+ eps, dim=0) + torch.abs(weight)/torch.sum(torch.abs(weight)+ eps, dim=1).reshape(-1, 1))
                
            elif self.args.pruning_method == "MEST":
                score_grow = torch.abs(self.backward_hook_objects[l].dense_grad)
                if self.is_dist:
                    dist.all_reduce(score_grow)
                    score_grow /= self.world_size
                weight_abs.append(torch.abs(weight) + self.args.factor * torch.abs(score_grow))
            else:
                raise NotImplementedError

        # Gather all scores in a single vector and normalise
        all_scores = torch.cat([torch.flatten(x) for x in weight_abs])
        num_params_to_keep = int(len(all_scores) * (1 - curr_prune_rate))

        threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
        acceptable_score = threshold[-1]
        
        total_size = 0
        sparse_size = 0
        for l, w in enumerate(self.W):
            if self.S[l] <= 0:
                continue
            self.backward_masks[l] = (weight_abs[l] > acceptable_score).bool().to(w.device) # must be > to prevent acceptable_score is zero, leading to dense tensors
            if self.is_dist:
                dist.broadcast(self.backward_masks[l], 0)
        
            total_size += self.N[l]
            sparse_size += torch.sum(self.backward_masks[l]).item()
            self.S[l] = 1 - torch.sum(self.backward_masks[l]).item() / self.N[l]
            # print(f"Layer {l}: mask sparsity is {torch.sum(self.backward_masks[l]).item() / self.N[l]}")
            
        
        print('Total Model parameters:', total_size)
        print('density after pruning: {0}'.format(
            sparse_size / total_size))
        

    @torch.no_grad()
    def _dst_step(self):
        """Run one topology-update step on the sparse masks.

        Sequence of operations:
          1. Reject the unsupported ``EM_S`` + ``adaptive_zeta`` combination.
          2. With ``history_weights``, copy the currently active weights of
             every layer into ``self.history_masks`` (weights are averaged
             across ranks first when ``self.is_dist`` is set).
          3. With ``granet`` or ``gmp``, prune according to
             ``self.sparsity_distribution`` and re-apply the masks.
          4. With ``gmp``, optionally run chain removal (re-applying the masks
             afterwards) and stop: no links are removed or regrown.
          5. Otherwise remove links, optionally run chain removal, regrow
             links and re-apply the masks.

        Raises:
            NotImplementedError: if ``EM_S`` and ``adaptive_zeta`` are both
                enabled, or if ``self.sparsity_distribution`` is neither
                ``"non-uniform"`` nor ``"uniform"`` while pruning is requested.
        """

        def reapply_masks():
            # Shared epilogue after every mask change: reset optimizer momentum,
            # then mask the (history) weights and finally the gradients.
            self.reset_momentum()
            if self.args.history_weights:
                self.apply_mask_to_history_weights()
            else:
                self.apply_mask_to_weights()
            self.apply_mask_to_gradients()

        if self.args.EM_S and self.args.adaptive_zeta:
            print("EM_S and adaptive_zeta cannot be used together!")
            raise NotImplementedError

        if self.args.history_weights:
            for l, w in enumerate(self.W):
                if self.is_dist:
                    # Average the weights over all ranks (in place).
                    dist.all_reduce(w)
                    w /= self.world_size
                active = self.backward_masks[l] == 1
                # Only active positions are refreshed; the stored copy lives on CPU.
                self.history_masks[l][active] = w.detach().clone()[active].cpu()

        if self.args.granet or self.args.gmp:
            pruners = {
                "non-uniform": self.non_uniform_pruning,
                "uniform": self.uniform_pruning,
            }
            prune = pruners.get(self.sparsity_distribution)
            if prune is None:
                raise NotImplementedError
            prune()
            reapply_masks()

        if self.args.gmp:
            # Gradual Magnitude Pruning: pruning only, no removal/regrowth cycle.
            if self.args.chain_removal:
                self.chain_removal()
                reapply_masks()
            return

        self.link_removal()
        if self.args.chain_removal:
            self.chain_removal()
        self.link_regrowth()
        reapply_masks()
    
    def link_removal(self):
        """Shrink the mask of every sparse layer by dropping a fraction of its links.

        Layers with ``self.S[l] <= 0`` are skipped. For the others the drop
        fraction comes from ``EM_S`` (derived from the layer sparsity and
        ``self.dense_allocation``), from ``self.cosine_annealing()`` when
        ``adaptive_zeta`` is set, or from ``self.alpha``. The links to keep
        are then chosen according to ``self.args.remove_method``:

        * ``weight_magnitude`` / ``ri`` / ``MEST``: keep the ``n_keep``
          highest-scoring positions.
        * ``weight_magnitude_soft`` / ``ri_soft``: sample ``max(1, n_keep)``
          positions without replacement, with probability proportional to
          ``score ** T``.

        When ``self.is_dist`` is set, scores are averaged across ranks. The
        resulting boolean mask replaces ``self.backward_masks[l]``.

        Raises:
            NotImplementedError: for an unknown ``remove_method``.
        """
        eps = 0.00001
        known_methods = ("weight_magnitude", "weight_magnitude_soft", "ri", "ri_soft", "MEST")

        def rank_average(tensor):
            # In-place mean over all ranks; a no-op outside distributed runs.
            if self.is_dist:
                dist.all_reduce(tensor)
                tensor /= self.world_size
            return tensor

        def relative_importance(magnitude):
            # Magnitude normalised by its column sum plus magnitude normalised
            # by its row sum (eps is added to every entry before summing).
            magnitude = torch.abs(magnitude)
            by_column = magnitude / torch.sum(magnitude + eps, dim=0)
            by_row = magnitude / torch.sum(magnitude + eps, dim=1).reshape(-1, 1)
            return by_column + by_row

        def keep_top(score, n_total, n_keep, device):
            # Full descending ranking; the first n_keep ranks get flag 1 and the
            # flags are scattered back to their original flat positions.
            _, order = torch.topk(score.view(-1), k=n_total)
            flags = torch.where(
                torch.arange(n_total, device=device) < n_keep,
                torch.ones_like(order),
                torch.zeros_like(order))
            return flags.scatter(0, order, flags)

        def keep_sampled(score, temperature, n_keep, device):
            kept = torch.zeros_like(score.view(-1)).to(device)
            tempered = (score.flatten()) ** temperature
            probabilities = tempered / tempered.sum()
            n_draw = max(1, n_keep)
            # Tensors with 2**24 or more entries are routed to custom_multinomial.
            if probabilities.size(dim=0) < 2**24:
                drawn = torch.multinomial(probabilities, n_draw, replacement=False)
            else:
                drawn = custom_multinomial(probabilities, n_draw,
                                           replacement=False, device=device)
            kept[drawn] = 1
            return kept

        for l, w in enumerate(self.W):
            if self.S[l] <= 0:
                continue

            if self.args.EM_S:
                drop_fraction = (1-self.S[l]-self.dense_allocation)/(1-self.S[l])
            elif self.args.adaptive_zeta:
                drop_fraction = self.cosine_annealing()
            else:
                drop_fraction = self.alpha

            current_mask = self.backward_masks[l]
            n_total = self.N[l]
            n_ones = torch.sum(current_mask).item()
            n_prune = int(n_ones * drop_fraction)
            n_keep = int(n_ones - n_prune)

            method = self.args.remove_method
            if method not in known_methods:
                raise NotImplementedError

            # Every supported method starts from the rank-averaged |w|.
            magnitude = rank_average(torch.abs(w))

            if method == "weight_magnitude":
                mask = keep_top(magnitude, n_total, n_keep, w.device)

            elif method == "weight_magnitude_soft":
                # Temperature moves linearly from start_T towards end_T with the step.
                temperature = self.args.start_T + self.step * ((self.args.end_T-self.args.start_T) / self.T_end)
                mask = keep_sampled(magnitude, temperature, n_keep, w.device)

            elif method == "ri":
                mask = keep_top(relative_importance(magnitude), n_total, n_keep, w.device)

            elif method == "ri_soft":
                importance = relative_importance(magnitude)
                temperature = 1 + self.step * (2 / self.T_end)
                mask = keep_sampled(importance, temperature, n_keep, w.device)

            else:
                # MEST: weight magnitude plus the scaled gradient magnitude of active links.
                gradient = rank_average(torch.abs(self.backward_hook_objects[l].dense_grad))
                combined = magnitude + self.args.factor * torch.abs(gradient * current_mask)
                mask = keep_top(combined, n_total, n_keep, w.device)

            self.backward_masks[l] = torch.reshape(mask, current_mask.shape).bool().to(w.device)
 
    def chain_removal(self):
        """Prune links that lost their neighbouring links along ``self.chain_list``.

        Two sweeps over the chained layers:

        * backward (last-but-one layer down to the first): each mask is passed
          with its successor's mask to ``remove_unactive_links_backward``;
        * forward (second layer up to the last): each mask is passed with its
          predecessor's mask to ``remove_unactive_links_forward``.

        The masks in ``self.backward_masks`` are replaced by the helpers'
        return values.
        """
        n_chained = len(self.chain_list)

        # Backward sweep: the successor has already been cleaned when we reach a layer.
        for position in range(n_chained - 2, -1, -1):
            layer = self.chain_list[position]
            successor = self.chain_list[position + 1]
            self.backward_masks[layer] = remove_unactive_links_backward(
                self.backward_masks[layer], self.backward_masks[successor])

        # Forward sweep: the predecessor has already been cleaned when we reach a layer.
        for position in range(1, n_chained):
            layer = self.chain_list[position]
            predecessor = self.chain_list[position - 1]
            self.backward_masks[layer] = remove_unactive_links_forward(
                self.backward_masks[layer], self.backward_masks[predecessor])
            
    
    def link_regrowth(self):
        """Regrow links in every sparse layer after link removal.

        Layers with ``self.S[l] <= 0`` are skipped. The number of links to
        regrow (``n_prune``) is:

        * with ``EM_S``: 5% of the layer size while
          ``step <= 0.6 * T_end``, 2.5% until ``T_end - delta_T``, and zero
          afterwards (``self.S[l]`` is updated accordingly);
        * otherwise: the gap between the layer's target number of active
          links, ``int((1 - S[l]) * N[l])``, and its current number.

        Candidate positions are scored according to ``self.args.regrow_method``
        (a CH variant with ``L3n``/``L3p`` in its name, ``"random"`` or
        ``"gradient"``); active positions always get score 0. A method
        containing ``"soft"`` samples positions with probability proportional
        to the score. Every other method grows exactly ``n_prune`` of the
        best-scoring inactive positions, breaking ties at the cut-off at
        random. Exact selection means a zero or tied cut-off can no longer
        turn the whole layer dense, as the former ``scores >= threshold``
        test did.

        With ``self.args.itop`` the layer's ``self.record_mask`` is OR-ed with
        the new mask and the explored fraction is printed.

        Raises:
            NotImplementedError: for an unknown regrow method, a CH method
                that names neither ``L3n`` nor ``L3p``, or an unknown CH
                variant on the ``L3p`` path.
        """
        for l, w in enumerate(self.W):
            if self.S[l] <= 0:
                continue
            if self.args.EM_S:
                if self.step <= self.T_end * 0.6:
                    self.S[l] = 1-self.dense_allocation-0.05
                    n_prune = int(0.05 * self.N[l])
                elif self.step < (self.T_end - self.delta_T):
                    self.S[l] = 1-self.dense_allocation-0.025
                    n_prune = int(0.025 * self.N[l])
                else:
                    self.S[l] = 1-self.dense_allocation
                    n_prune = 0
                    print("Final sparsity: {}".format(torch.sum(self.backward_masks[l]).item()/self.N[l]))
                    continue
            else:
                n_prune = int((1-self.S[l]) * self.N[l]) - torch.sum(self.backward_masks[l]).item()

            print(f"number of pruning and regrowth in layer {l}: {n_prune}")
            if n_prune <= 0:
                continue

            current_mask = self.backward_masks[l].data.clone().float()
            if "ch" in self.args.regrow_method.lower():
                CH_method = self.args.regrow_method.split("_")[0]

                if "L3n" in self.args.regrow_method:
                    # Dense-matrix CH scores built from the layer mask and its transpose.
                    DTPATHS1 = current_mask

                    TDPATHS1 = DTPATHS1.transpose(1, 0)
                    DDPATHS2 = torch.matmul(DTPATHS1, TDPATHS1)
                    TTPATHS2 = torch.matmul(TDPATHS1, DTPATHS1)

                    BDDPATHS2 = DDPATHS2 != 0
                    BTTPATHS2 = TTPATHS2 != 0

                    elcl_DT = (torch.sum(DTPATHS1, dim=1) - DDPATHS2) * BDDPATHS2
                    elcl_TD = (torch.sum(TDPATHS1, dim=1) - TTPATHS2) * BTTPATHS2

                    elcl_DT[elcl_DT == 0] = 1
                    elcl_TD[elcl_TD == 0] = 1

                    elcl_DT -= 1
                    elcl_TD -= 1
                    # NOTE(review): any other CH variant falls through with the raw
                    # counts above (the L3p path raises instead) - confirm intent.
                    if CH_method == "CH2":
                        elcl_DT = 1 / (elcl_DT + 1) * (DDPATHS2 + BDDPATHS2)
                        elcl_TD = 1 / (elcl_TD + 1) * (TTPATHS2 + BTTPATHS2)
                    elif CH_method == "CH3":
                        elcl_DT = 1 / (elcl_DT + 1) * BDDPATHS2
                        elcl_TD = 1 / (elcl_TD + 1) * BTTPATHS2
                    elif CH_method == "CH3.1":
                        elcl_DT = 1 / ((elcl_DT + 1) ** (1 + (elcl_DT/ (1+elcl_DT)))) * (DDPATHS2 + BDDPATHS2)
                        elcl_TD = 1 / ((elcl_TD + 1) ** (1 + (elcl_TD/ (1+elcl_TD)))) * (TTPATHS2 + BTTPATHS2)

                    elcl_DT = torch.matmul(elcl_DT, DTPATHS1)
                    elcl_TD = torch.matmul(elcl_TD, TDPATHS1)

                    scores = elcl_DT + elcl_TD.T
                    scores = scores * (current_mask == 0)
                    thre = torch.sort(scores.ravel())[0][-n_prune]
                    if thre == 0:
                        # Too few positive scores: give every inactive position a
                        # small positive score so it can still be selected/sampled.
                        print("Regrowing threshold is 0!!!")
                        scores = (scores + 0.00001)*(current_mask==0)

                elif "L3p" in self.args.regrow_method:
                    # CH3_L3 path-based regrowth on the monopartite form of the mask.
                    xb = np.array(current_mask.cpu())
                    x = transform_bi_to_mo(xb)

                    A = csr_matrix(x)
                    ir = A.indices
                    jc = A.indptr
                    if CH_method == "CH2":
                        scores_cell = torch.tensor(np.array(CH_scores.CH_scores_new_v2(ir, jc, x.shape[0], [3], 1, 3, [2], 1))).to(w.device)
                    elif CH_method == "CH3":
                        scores_cell = torch.tensor(np.array(CH_scores.CH_scores_new_v2(ir, jc, x.shape[0], [3], 1, 3, [3], 1))).to(w.device)
                    elif CH_method == "CH3.1":
                        scores_cell = torch.tensor(np.array(CH_scores.CH_scores_new_v2(ir, jc, x.shape[0], [3], 1, 3, [5], 1))).to(w.device)
                    else:
                        raise NotImplementedError
                    scores = torch.reshape(scores_cell, x.shape)
                    # Keep only the top-right block, i.e. the original bipartite layout.
                    scores = scores[:xb.shape[0], xb.shape[0]:]

                    scores = scores * (current_mask == 0)

                    thre = torch.sort(scores.ravel())[0][-n_prune]
                    if thre == 0:
                        print("Regrowing threshold is 0!!!")
                        print(f"# of scores: {torch.sum(scores > 0)}")
                        scores = (scores + 0.00001)*(current_mask==0)

                else:
                    # Previously fell through and failed later on the undefined `scores`.
                    raise NotImplementedError

            elif self.args.regrow_method == "random":
                # random regrowth
                scores = torch.rand(w.shape).to(w.device) * (current_mask == 0)

            elif self.args.regrow_method == "gradient":
                scores = torch.abs(self.backward_hook_objects[l].dense_grad) * (current_mask == 0)

            else:
                raise NotImplementedError

            if "soft" in self.args.regrow_method:
                mask = torch.zeros_like(scores.view(-1)).to(w.device)
                flat_matrix = scores.flatten()
                probabilities = flat_matrix / flat_matrix.sum()
                if probabilities.size(dim=0)<2**24:
                    sampled_flat_indices = torch.multinomial(probabilities, max(1, n_prune), replacement=False)
                else:
                    sampled_flat_indices = custom_multinomial(probabilities, max(1, n_prune),
                                                                    replacement=False,device=w.device)
                mask[sampled_flat_indices] = 1
            else:
                # Exact-k selection restricted to inactive positions.
                flat_scores = scores.reshape(-1)
                if not flat_scores.is_floating_point():
                    flat_scores = flat_scores.float()
                is_candidate = (current_mask == 0).reshape(-1).to(flat_scores.device)
                n_grow = min(int(n_prune), int(is_candidate.sum().item()))
                mask = torch.zeros_like(flat_scores).to(w.device)
                if n_grow > 0:
                    # Active links can never win: they are ranked at -inf.
                    ranked = flat_scores.masked_fill(~is_candidate, float("-inf"))
                    cutoff = torch.topk(ranked, n_grow).values[-1]
                    above = ranked > cutoff
                    mask[above] = 1
                    # Fill the remaining slots from the positions tied at the
                    # cut-off; shuffle only when there are more ties than slots.
                    n_missing = n_grow - int(above.sum().item())
                    tied = torch.nonzero(ranked == cutoff).flatten()
                    if tied.numel() > n_missing:
                        shuffle = torch.randperm(tied.numel(), device=tied.device)
                        tied = tied[shuffle[:n_missing]]
                    mask[tied] = 1

            new_link_mask = torch.reshape(mask, current_mask.shape)
            self.backward_masks[l] = self.backward_masks[l] | (new_link_mask.bool())

            if self.args.itop:
                self.record_mask[l] = ((self.record_mask[l] == 1) | (self.backward_masks[l]))
                print("ITOP rate is : ", (torch.sum(self.record_mask[l]) / self.N[l]).item())
    
    @torch.no_grad()
    def reset_parameters(self):
        """Re-draw every weight in ``self.W`` from a zero-mean normal distribution.

        The standard deviation depends on ``self.args.init_mode``:

        * ``"swi"``: ``sqrt(2 / (active_links / w.shape[1]))`` where
          ``active_links = (1 - S[l]) * N[l]``;
        * ``"kaiming"``: ``sqrt(2 / w.shape[1])``.

        Samples are drawn with ``torch.randn`` and then moved to the weight's
        device.

        Raises:
            NotImplementedError: for any other ``init_mode``.
        """
        for l, w in enumerate(self.W):
            n_rows, n_cols = w.shape[0], w.shape[1]
            mode = self.args.init_mode
            if mode == "swi":
                active_links = (1-self.S[l]) * self.N[l]
                stdv = math.sqrt(2. / (active_links / n_cols))
            elif mode == "kaiming":
                stdv = math.sqrt(2 / n_cols)
            else:
                raise NotImplementedError
            fresh = torch.randn(n_rows, n_cols) * stdv
            w.data = fresh.to(w.device)

def transform_bi_to_mo(xb):
    """Embed a bipartite adjacency matrix into a square monopartite one.

    For ``xb`` of shape ``(r, c)`` the result is an ``(r + c, r + c)`` array
    of zeros with ``xb`` in the top-right block and ``xb.T`` in the
    bottom-left block, so the output is symmetric.
    """
    n_rows, n_cols = xb.shape[0], xb.shape[1]
    n_nodes = n_rows + n_cols

    x = np.zeros((n_nodes, n_nodes))
    # Top-right block: links from the first node set to the second.
    x[0:n_rows, n_rows:n_nodes] = xb
    # Bottom-left block: the same links seen from the other side.
    x[n_rows:n_nodes, 0:n_rows] = xb.T
    return x


def load_calib_dataset(args, data_root="../data"):
    """Build a shuffled training DataLoader for ``args.dataset`` plus a zero buffer.

    Args:
        args: namespace providing ``dataset`` and ``batch_size`` (the MNIST
            branch reads ``calib_samples`` instead of ``batch_size``).
        data_root: directory under which the datasets are stored/downloaded.

    Returns:
        ``(dataloader, input_of_sparse_layer)`` where ``input_of_sparse_layer``
        is a float64 ``np.zeros`` array whose shape is hard-coded per dataset
        (presumably ``(flattened input size, number of training samples)``).

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """

    def tensor_only():
        # Created on demand so an unsupported dataset name fails before any
        # torchvision object is built.
        return transforms.Compose([transforms.ToTensor()])

    def shuffled_loader(dataset, batch_size=None):
        # All branches shuffle; batch size defaults to args.batch_size.
        if batch_size is None:
            batch_size = args.batch_size
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    name = args.dataset

    if name == "MNIST":
        # NOTE(review): the only branch batching by args.calib_samples rather
        # than args.batch_size - confirm this asymmetry is intended.
        dataloader = shuffled_loader(
            datasets.MNIST(data_root, train=True, download=True, transform=tensor_only()),
            batch_size=args.calib_samples)
        input_of_sparse_layer = np.zeros((784, 60000))

    elif name == "Fashion_MNIST":
        dataloader = shuffled_loader(
            datasets.FashionMNIST(root=data_root, train=True, transform=tensor_only(), download=True))
        input_of_sparse_layer = np.zeros((784, 60000))

    elif name == "EMNIST":
        dataloader = shuffled_loader(
            datasets.EMNIST(root=data_root, train=True, transform=tensor_only(),
                            download=True, split='balanced'))
        input_of_sparse_layer = np.zeros((784, 50000))

    elif name == "CIFAR10":
        dataloader = shuffled_loader(
            datasets.CIFAR10(data_root, train=True, download=True, transform=tensor_only()))
        input_of_sparse_layer = np.zeros((3072, 50000))

    elif name == "CIFAR100":
        dataloader = shuffled_loader(
            datasets.CIFAR100(data_root, train=True, download=True, transform=tensor_only()))
        input_of_sparse_layer = np.zeros((3072, 50000))

    elif name == "FER2013":
        dataloader = shuffled_loader(
            datasets.FER2013(data_root, split='train', download=True, transform=tensor_only()))
        input_of_sparse_layer = np.zeros((48 * 48, 28709))

    elif name == "SVHN":
        dataloader = shuffled_loader(
            datasets.SVHN(os.path.join(data_root, 'svhn'), split='train', download=True,
                          transform=tensor_only()))
        input_of_sparse_layer = np.zeros((32 * 32 * 3, 73257))

    elif name in ("tiny-imagenet-ori", "tiny-imagenet-crop", "tiny-imagenet-resize"):
        # The three variants share the loader and differ only in the buffer size.
        dataloader = shuffled_loader(
            TinyImageNet_load(root=os.path.join(data_root, 'tiny-imagenet-200'), train=True,
                              transform=tensor_only()))
        side = 64 if name == "tiny-imagenet-ori" else 48
        input_of_sparse_layer = np.zeros((side * side * 3, 100000))

    elif name == "OxfordFlowers":
        # Training set is the official train split plus the validation split.
        train_val = datasets.Flowers102(root=data_root, split='train', download=True, transform=None)
        val_set = datasets.Flowers102(root=data_root, split='val', download=True, transform=None)
        train_set = torch.utils.data.ConcatDataset([train_val, val_set])

        # The transform is applied through the wrapper, not by the raw datasets.
        train_set = TransformedSubset(train_set, transform=transforms.ToTensor())

        dataloader = shuffled_loader(train_set)
        input_of_sparse_layer = np.zeros((64 * 64 * 3, 2040))

    elif name == "Caltech256":
        def ensure_rgb(image):
            if image.mode != 'RGB':
                return image.convert('RGB')
            return image

        # Load the full dataset without a transform, then split 80/20.
        full_set = datasets.Caltech256(root=data_root, download=False, transform=None)
        transform_train_rgb_caltech = transforms.Compose([
            transforms.Lambda(ensure_rgb),
            transforms.ToTensor(),
        ])
        train_size = int(0.8 * len(full_set))
        test_size = len(full_set) - train_size
        train_sub, test_sub = torch.utils.data.random_split(full_set, [train_size, test_size])

        train_set = TransformedSubset(train_sub, transform=transform_train_rgb_caltech)

        dataloader = shuffled_loader(train_set)
        input_of_sparse_layer = np.zeros((64 * 64 * 3, 24485))

    elif name == "OxfordIIITPet":
        dataloader = shuffled_loader(
            datasets.OxfordIIITPet(data_root, split='trainval', download=True,
                                   transform=tensor_only()))
        input_of_sparse_layer = np.zeros((64 * 64 * 3, 3680))

    elif name == "INaturalist":
        # torchvision's INaturalist selects the subset through `version`; it has
        # no `split` argument. '2021_train_mini' is the 500,000-image subset
        # that matches the buffer below.
        dataloader = shuffled_loader(
            datasets.INaturalist(data_root, version='2021_train_mini', download=True,
                                 transform=tensor_only()))
        input_of_sparse_layer = np.zeros((64 * 64 * 3, 500000))

    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    return dataloader, input_of_sparse_layer

    