import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


@dataclass
class CHTConfig:
    """Hyper-parameters read by ``Conv2d_CHT`` for sparse connectivity and link evolution.

    Fields are positional in the order below; do not reorder them.
    """

    # --- sparsity and link evolution ---
    sparsity: float             # fraction of each mask row (zone_sz * K * K entries) kept inactive; 0 -> dense layer
    mlp_sparsity: float         # stored by Conv2d_CHT but not used by the conv code visible here
    link_update_ratio: float    # fraction of active links removed/regrown per evolve() call
    remove_method: str          # removal criterion: 'rand' or 'wm' (smallest |weight|-based score)
    regrow_method: str          # regrowth criterion: 'rand' or 'L3n'

    # --- mask sharing ---
    shared_mask_sw: bool        # one mask shared by every sliding window (SW)
    shared_mask_zone: bool      # share the mask across input-channel zones
    zone_sz: int                # input channels per zone; 0 or a non-divisor of c_in falls back to c_in
    avg_remove: bool            # score removals by averaging across SWs
    avg_regrow: bool            # score regrowths by averaging across SWs

    # --- soft sampling (CHTs) ---
    soft: bool                  # sample removals/regrowths instead of a hard top-k
    use_opt4: bool              # enable "option 4" in the L3n regrowth scoring
    delta: float                # sampling sharpness: scores are raised to delta / (1 - delta)
    delta_max: float            # delta stops increasing once it reaches this value
    delta_d: float              # amount added to delta after each evolve() call
    ch_method: str              # CH variant used for L3n scoring (e.g. 'CH3', 'CH2', 'CH3.1')

    # --- implementation switches ---
    use_hidden: bool            # keep a per-SW hidden mask behind the SW-shared mask
    l3n_batch_sz: int           # output channels per L3n batch; 0 means all of c_out
    evolve_es: bool             # allow evolve() to early-stop once regrowth mostly undoes removal
    use_manual: bool            # use unfold + matmul instead of F.conv2d where applicable


class Conv2d_CHT(nn.Module):

    def __init__(
            self, c_in, c_out, kernel_sz,
            cht_config: CHTConfig,
            *, padding=0, stride=1):
        
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.kernel_size = kernel_sz

        self.padding = padding
        self.stride = stride

        # CHT
        self.sparsity = cht_config.sparsity
        self.mlp_sparsity = cht_config.mlp_sparsity
        self.link_update_ratio = cht_config.link_update_ratio
        self.remove_method = cht_config.remove_method
        self.regrow_method = cht_config.regrow_method
        self.shared_mask_sw = cht_config.shared_mask_sw
        self.shared_mask_zone = cht_config.shared_mask_zone
        self.zone_sz = cht_config.zone_sz
        if self.zone_sz == 0:
            self.zone_sz = c_in  # The entire mask is a single zone
        elif self.zone_sz > c_in or c_in % self.zone_sz != 0:
            self.zone_sz = c_in
        self.num_zones = c_in // self.zone_sz
        self.avg_remove = cht_config.avg_remove        # Whether to remove links according to average scores between SWs: 1.3
        self.avg_regrow = cht_config.avg_regrow        # Whether to regrow links according to average scores between SWs: 1.2, 1.3

        # CHTs
        self.soft = cht_config.soft                    # Whether to use soft removal and regrowth i.e. CHTs
        self.use_opt4 = cht_config.use_opt4           # Whether to use option 4 in _get_L3n_regrow_pos method
        self.delta = cht_config.delta
        self.delta_max = cht_config.delta_max
        self.delta_d = cht_config.delta_d
        self.ch_method = cht_config.ch_method
        self.use_hidden = cht_config.use_hidden        # Whether to use hidden layer functionality
        self.l3n_batch_sz = cht_config.l3n_batch_sz    # Batch size for L3n computation
        if self.l3n_batch_sz == 0 or self.l3n_batch_sz > self.c_out:
            self.l3n_batch_sz = self.c_out
        
        self.evolve_es = cht_config.evolve_es          # Whether to enable early stop for evolve
        self.evolve_se_flag = False                    # Early stop flag for evolve
        self.use_manual = cht_config.use_manual        # Whether to use manual convolution computation

        weight = torch.empty(c_out, c_in, kernel_sz, kernel_sz)
        # torch.nn.init.kaiming_normal_(weight, mode='fan_out', nonlinearity='relu')
        # Use manual convolution computation other than pytorch's built-in implementation
        if self.sparsity != 0. and not self.shared_mask_sw:
            weight = weight.view(c_out, -1)
        # weight = weight.view(c_out, -1)  # TODO: test!->related to init??
        self.weight = nn.Parameter(weight)  # [c_out, c_in * kernel_sz * kernel_sz]

        # Initialize bias
        bias = torch.randn(c_out)
        # Use manual convolution computation other than pytorch's built-in implementation
        if self.sparsity != 0. and not self.shared_mask_sw:
            bias = bias.view(-1, 1)
        self.bias = nn.Parameter(bias)  # [c_out, 1]

        # Initialize mask
        self.mask_initialized = False
        self.input_param_initialized = False

        self.num_zeros = int(self.sparsity * self.zone_sz * kernel_sz * kernel_sz)
        self.num_active = self.zone_sz * kernel_sz * kernel_sz - self.num_zeros
        self.num_update = int(self.link_update_ratio * self.num_active)

        print(f'[DEBUG] {self.num_active = }, {self.link_update_ratio = }, {self.num_update = }')
        # print(f'[DEBUG] Use manual convolution: {self.use_manual}')
        

        # Initialize mask
        if self.sparsity == 0.:
            '''Mask is not applicable for dense model'''
            pass

        elif self.shared_mask_sw and not self.use_hidden:
            '''
            Shared mask across SW, i.e. the same mask is applied to each kernel across SWs.
            '''
            mask = torch.zeros(c_out, self.zone_sz * kernel_sz * kernel_sz, dtype=torch.bool)

            rand_values = torch.rand_like(mask, dtype=torch.float)
            _, indices = torch.topk(rand_values, self.num_active, dim=1)
            mask.scatter_(1, indices, True)

            # mask = mask.view(c_out, -1)  # [c_out, zone_sz * kernel_sz * kernel_sz]
            mask = mask.view(c_out, self.zone_sz, kernel_sz, kernel_sz)
            self.register_buffer('mask', mask)
            self._check_mask(self.mask, self.num_zeros)

            self.mask_initialized = True
        
        else:
            '''
            Different mask for each SW, though they can be
            shared across zones of the same kernel (shared_mask_zone=True).

            As the number of SW is not fixed and determined by the input shape,
            the mask is initialized lazily.
            '''
            pass
        

    def forward(self, x):

        batch_sz, c_in, h, w = x.shape

        # TODO: test!
        # # Calculate output dimensions, out_h * out_w = num_sw
        # if not self.input_param_initialized:
        #     out_h = (h + 2 * self.padding - self.kernel_size) // self.stride + 1
        #     out_w = (w + 2 * self.padding - self.kernel_size) // self.stride + 1
        #     self.out_h = out_h
        #     self.out_w = out_w
        #     self.in_h = h
        #     self.in_w = w
        #     # Calculate the conv with unfolded input
        #     x_unf = self._unfold_input(x)  # [batch_sz, c_in * K * K, num_sw]
        #     self.num_sw = x_unf.shape[-1]
        #     self.input_param_initialized = True


        # if self.sparsity == 0.:
        #     x_unf = self._unfold_input(x)  # [batch_sz, c_in * K * K, num_sw]
        #     # out = self.weight @ x_unf  # [batch_sz, c_out, out_h*out_w]
        #     out = self.weight.view(self.c_out, -1) @ x_unf  # [batch_sz, c_out, out_h*out_w]
        #     out += self.bias.unsqueeze(1)  # [batch_sz, c_out, out_h*out_w]

        # output = out.view(batch_sz, self.c_out, self.out_h, self.out_w)  # [batch_sz, c_out, out_h, out_w]
        # return output


        if not self.input_param_initialized:
            # Calculate output dimensions, out_h * out_w = num_sw
            out_h = (h + 2 * self.padding - self.kernel_size) // self.stride + 1
            out_w = (w + 2 * self.padding - self.kernel_size) // self.stride + 1
            self.out_h = out_h
            self.out_w = out_w
            self.in_h = h
            self.in_w = w
            # Calculate the conv with unfolded input
            x_unf = self._unfold_input(x)  # [batch_sz, c_in * K * K, num_sw]
            self.num_sw = x_unf.shape[-1]
            self.input_param_initialized = True
        
        if self.sparsity == 0.:
            if self.use_manual:
                x_unf = self._unfold_input(x)  # [batch_sz, c_in * K * K, num_sw]
                # weight_flatten = self.weight.flatten(1)  # [c_out, c_in * K * K]
                # out = weight_flatten @ x_unf  # [batch_sz, c_out, out_h*out_w]
                out = self.weight @ x_unf  # [batch_sz, c_out, out_h*out_w]
                out += self.bias.unsqueeze(1)  # [batch_sz, c_out, out_h*out_w]
            else:
                out = F.conv2d(x, self.weight, self.bias, stride=self.stride, padding=self.padding)

        elif self.shared_mask_sw:
            if not self.mask_initialized:
                # Initialize the hidden mask
                assert self.use_hidden
                hidden_mask = torch.zeros(  # [c_out, num_sw, zone_sz * K * K]
                    self.c_out,
                    self.num_sw,
                    self.zone_sz * self.kernel_size * self.kernel_size,
                    dtype=torch.bool)
                rand_values = torch.rand_like(hidden_mask, dtype=torch.float)
                _, indices = torch.topk(rand_values, self.num_active, dim=2)
                hidden_mask.scatter_(2, indices, True)
                
                hidden_mask = hidden_mask.to(self.weight.device)
                self.register_buffer('hidden_mask', hidden_mask)
                self._check_mask(self.hidden_mask, self.num_zeros)

                # Register the observed mask
                mask = self._hidden_to_mask()
                self.register_buffer('mask', mask)
                self._check_mask(self.mask, self.num_zeros)

                self.mask_initialized = True

            repeated_mask = self.mask.repeat(1, self.num_zones, 1, 1)
            masked_weight = self.weight.masked_fill(~repeated_mask, 0.)

            if self.use_manual:
                x_unf = self._unfold_input(x)  # [batch_sz, c_in * K * K, num_sw]
                masked_weight = masked_weight.flatten(1)  # [c_out, c_in * K * K]
                out = masked_weight @ x_unf  # [batch_sz, c_out, out_h*out_w]
                out += self.bias.unsqueeze(1)  # [batch_sz, c_out, out_h*out_w]
            else:
                out = F.conv2d(x, masked_weight, self.bias, stride=self.stride, padding=self.padding)

        elif self.shared_mask_zone:  # shared_mask_sw = False, shared_mask_zone = True
            if not self.mask_initialized:  # Lazy initialization
                mask = torch.zeros(  # [c_out, num_sw, zone_sz * K * K]
                    self.c_out, self.num_sw,
                    self.zone_sz * self.kernel_size * self.kernel_size,
                    dtype=torch.bool)
                rand_values = torch.rand_like(mask, dtype=torch.float)
                _, indices = torch.topk(rand_values, self.num_active, dim=2)
                mask.scatter_(2, indices, True)
                
                mask = mask.to(self.weight.device)
                self.register_buffer('mask', mask)
                self._check_mask(self.mask, self.num_zeros)

                self.mask_initialized = True

            # weight: [c_out, c_in * K * K]
            # mask: [c_out, num_sw, zone_sz * K * K]
            repeated_mask = self.mask.repeat(1, 1, self.num_zones)  # [c_out, num_sw, c_in * K * K]
            masked_weight = self.weight.unsqueeze(1) * repeated_mask  # [c_out, num_sw, c_in * K * K]
            x_unf = self._unfold_input(x)
            out = torch.einsum('cnk,bkn->bcn', masked_weight, x_unf)
            out += self.bias  # [batch_sz, c_out, num_sw]

        else:  # shared_mask_sw = False, shared_mask_zone = False
            # TODO: introduce zone after optimized
            if not self.mask_initialized:  # Lazy initialization
                w_shape = self.weight.shape
                mask = torch.zeros(w_shape[0], self.num_sw, *w_shape[1:], dtype=torch.bool)  # [c_out, num_sw, c_in, kernel_sz, kernel_sz]
                mask = mask.flatten(2)  # [c_out, num_sw, c_in * kernel_sz * kernel_sz]

                # Initialize the active positions
                rand_values = torch.rand_like(mask, dtype=torch.float)
                _, indices = torch.topk(rand_values, self.num_active, dim=2)
                mask.scatter_(2, indices, True)
                
                mask = mask.to(self.weight.device)
                self.register_buffer('mask', mask)

                self.mask_initialized = True
            
            # weight: [c_out, c_in * K * K]
            masked_weight = self.weight.unsqueeze(1) * self.mask  # [c_out, num_sw, c_in * K * K]
            x_unf = self._unfold_input(x)
            out = torch.einsum('cnk,bkn->bcn', masked_weight, x_unf)
            out += self.bias  # [batch_sz, c_out, num_sw]
        
        output = out.view(batch_sz, self.c_out, self.out_h, self.out_w)  # [batch_sz, c_out, out_h, out_w]
        return output


    @torch.no_grad()
    def evolve(self):
        if self.sparsity == 0. or self.link_update_ratio == 0.:
            return None, None, None

        # Early stop check
        if self.evolve_es and self.evolve_se_flag and self.regrow_method == 'L3n':
            return None, None, None

        remove_pos = self._remove()

        regrow_pos = self._regrow()
        
        # Calculate the cancellation ratio
        overlap = (remove_pos & regrow_pos).sum().item()
        total_remove = remove_pos.sum().item()
        cancellation_ratio = overlap / total_remove

        # Early stop logic
        if self.evolve_es and cancellation_ratio > 0.995:
            self.evolve_se_flag = True

        if self.delta < self.delta_max:
            self.delta += self.delta_d

        if self.shared_mask_sw and self.use_hidden:
            self.mask = self._hidden_to_mask()
        
        # Calculate mask convergence scores
        min_score, mean_score = self.calculate_mask_convergence()
        
        return cancellation_ratio, min_score, mean_score


    def _remove(self):
        if not self.use_hidden:
            remove_pos = torch.zeros_like(self.mask)  # dtype: bool
        else:
            remove_pos = torch.zeros_like(self.hidden_mask)  # dtype: bool

        if self.shared_mask_sw:
            mask_flatten = self.mask.flatten(1)  # [c_out, zone_sz * K * K]
            weight_flatten = self.weight.flatten(1)  # [c_out, c_in * K * K]
            remove_pos_flatten = remove_pos.flatten(1)  # [c_out, zone_sz * K * K]

            if not self.shared_mask_zone:
                match self.remove_method:
                    case 'rand':
                        rand_values = torch.rand_like(mask_flatten, dtype=torch.float)
                        rand_values.masked_fill_(~mask_flatten, 0.)  # Omit inactive positions
                        _, indices = torch.topk(rand_values, k=self.num_update, dim=-1)
                        remove_pos_flatten.scatter_(1, indices, True)

                    case 'wm':
                        masked_weight = weight_flatten.masked_fill(~mask_flatten, torch.inf)  # Omit the inactive positions
                        _, indices = torch.topk(masked_weight.abs(), k=self.num_update, largest=False)
                        remove_pos_flatten.scatter_(1, indices, True)

                    case _:
                        raise NotImplementedError
                
            else:
                if not self.use_hidden:
                    match self.remove_method:
                        case 'wm':
                            avg_weight = self.weight.view(
                                self.c_out,
                                self.num_zones,
                                self.zone_sz * self.kernel_size * self.kernel_size
                            )
                            avg_weight = avg_weight.mean(dim=1)  # [c_out, zone_sz * K * K]
                            masked_weight = avg_weight.masked_fill(~mask_flatten, torch.inf)  # Omit the inactive positions
                            _, indices = torch.topk(masked_weight.abs(), k=self.num_update, largest=False)
                            remove_pos_flatten.scatter_(1, indices, True)
                        
                        case _:
                            raise NotImplementedError
                else:  # use_hidden = True
                    match self.remove_method:
                        case 'wm':
                            avg_weight = self.weight.view(
                                self.c_out,
                                self.num_zones,
                                self.zone_sz * self.kernel_size * self.kernel_size
                            )
                            # Average across zones
                            avg_weight = avg_weight.mean(dim=1, keepdim=True)  # [c_out, 1, zone_sz * K * K]
                            # Average across SWs
                            scores = avg_weight.masked_fill(~self.hidden_mask, 0.)
                            scores = scores.mean(dim=1, keepdim=True)  # [c_out, 1, zone_sz * K * K]

                            scores = scores.masked_fill(~self.hidden_mask, torch.inf)  # Omit the inactive positions
                            _, indices = torch.topk(scores.abs(), k=self.num_update, largest=False)
                            remove_pos_flatten.scatter_(2, indices, True)
                        
                        case _:
                            raise NotImplementedError
            
            remove_pos = remove_pos_flatten.view(remove_pos.shape)

        elif self.shared_mask_zone:  # shared_mask_sw = False, shared_mask_zone = True
            # weight: [c_out, c_in * K * K]
            # mask: [c_out, num_sw, zone_sz * K * K]
            match self.remove_method:
                case 'rand':
                    rand_values = torch.rand_like(self.mask, dtype=torch.float)
                    rand_values.masked_fill_(~self.mask, 0.)  # Omit inactive positions
                    _, indices = torch.topk(rand_values, k=self.num_update, dim=-1)
                    remove_pos.scatter_(2, indices, True)

                case 'wm':
                    # Compute the average weight across input channels
                    avg_weight = self.weight.view(
                        self.c_out,
                        self.num_zones,
                        self.zone_sz * self.kernel_size * self.kernel_size)
                    avg_weight = avg_weight.mean(dim=1, keepdim=True)  # [c_out, 1, zone_sz * K * K]

                    if not self.avg_remove:
                        scores = avg_weight
                    else:
                        scores = avg_weight.masked_fill(~self.mask, 0.).mean(dim=1, keepdim=True)  # [c_out, 1, zone_sz * K * K]
                                        
                    if not self.soft:
                        scores = scores.masked_fill(~self.mask, torch.inf)   # Omit the inactive positions
                        _, indices = torch.topk(scores.abs(), k=self.num_update, largest=False)
                        remove_pos.scatter_(2, indices, True)
                    else:
                        scores = scores.masked_fill(~self.mask, 0.)  # Omit the inactive positions by set the prob to 0
                        scores = scores.abs().flatten(0, -2)  # [c_out * num_sw, zone_sz * K * K]
                        
                        exp = self.delta / (1 - self.delta)
                        scores **= exp

                        indices = torch.multinomial(scores, self.num_active - self.num_update, replacement=False)
                        indices = indices.view(self.c_out, self.num_sw, -1)
                        remove_pos.scatter_(2, indices, True)
                        remove_pos = ~remove_pos
                        remove_pos[~self.mask] = False  # Remove the current inactive positions from the record

                case _:
                    raise NotImplementedError

        else:  # shared_mask_sw = False, shared_mask_zone = False
            # weight: [c_out, c_in * K * K]
            # mask: [c_out, num_sw, c_in * K * K]

            match self.remove_method:  # TODO: introduce zone after optimized, rand
                case 'wm':
                    if not self.avg_remove:
                        scores = self.weight
                    else:
                        repeated_weight = self.weight.unsqueeze(1).repeat(1, self.num_sw, 1)
                        scores = repeated_weight.masked_fill(~self.mask, 0.).mean(dim=1)  # [c_out, c_in * kernel_sz * kernel_sz]

                    scores = scores.unsqueeze(1).repeat(1, self.num_sw, 1)  # To apply the mask on it
                    if not self.soft:
                        scores.masked_fill_(~self.mask, torch.inf)   # Omit the inactive positions
                        _, indices = torch.topk(scores.abs(), k=self.num_update, largest=False)
                        remove_pos.scatter_(2, indices, True)
                    else:
                        scores.masked_fill_(~self.mask, 0.)  # Omit the inactive positions by set the prob to 0
                        scores = scores.abs().flatten(0, -2)  # [c_out * num_sw, c_in * K * K]

                        exp = self.delta / (1 - self.delta)
                        scores **= exp

                        indices = torch.multinomial(scores, self.num_active - self.num_update, replacement=False)
                        indices = indices.view(self.c_out, self.num_sw, -1)
                        remove_pos.scatter_(2, indices, True)
                        remove_pos = ~remove_pos
                        remove_pos[~self.mask] = False  # Remove the current inactive positions from the record

                case _:
                    raise NotImplementedError

        # Update mask
        if not self.use_hidden:
            self.mask.masked_fill_(remove_pos, False)
            self._check_mask(self.mask, self.num_zeros + self.num_update)  # TODO: zone_sz and not shared_mask_zone
        else:
            self.hidden_mask.masked_fill_(remove_pos, False)
            self._check_mask(self.hidden_mask, self.num_zeros + self.num_update)

        return remove_pos
    
    
    def _regrow(self):
        if not self.use_hidden:
            regrow_pos = torch.zeros_like(self.mask)  # dtype: bool
        else:
            regrow_pos = torch.zeros_like(self.hidden_mask)  # dtype: bool

        if self.shared_mask_sw:
            mask_flatten = self.mask.flatten(1)  # [c_out, zone_sz * K * K]
            regrow_pos_flatten = regrow_pos.flatten(1)  # [c_out, zone_sz * K * K]

            if not self.shared_mask_zone:
                match self.regrow_method:
                    case 'rand':
                        rand_values = torch.rand_like(mask_flatten, dtype=torch.float)
                        rand_values.masked_fill_(mask_flatten, 0.)  # Omit active positions
                        _, indices = torch.topk(rand_values, k=self.num_update, dim=-1)
                        regrow_pos_flatten.scatter_(1, indices, True)

                    case _:
                        raise NotImplementedError

            else:
                if not self.use_hidden:
                    match self.regrow_method:
                        case 'L3n':  # Regrow with position average
                            mask_repeated = mask_flatten.unsqueeze(1).repeat(1, self.num_sw, 1)
                            
                            # regrow_pos_extended = self._get_L3n_regrow_pos_optimized(mask_repeated)  # [c_out, num_sw, zone_sz * K * K]
                            regrow_pos_extended = []
                            for start in range(0, self.c_out, self.l3n_batch_sz):
                                end = start + self.l3n_batch_sz
                                regrow_pos_part = self._get_L3n_regrow_pos_optimized(mask_repeated[start:end])
                                regrow_pos_extended.append(regrow_pos_part)
                            regrow_pos_extended = torch.cat(regrow_pos_extended, dim=0)  # [c_out, num_sw, zone_sz * K * K]

                            regrow_pos_avg = regrow_pos_extended.float().mean(dim=1)  # [c_out, zone_sz * K * K]
                            _, indices = torch.topk(regrow_pos_avg, k=self.num_update, dim=1)
                            regrow_pos_flatten.scatter_(1, indices, True)
                        
                        case _:
                            raise NotImplementedError
                
                else:  # use_hidden = True
                    match self.regrow_method:
                        case 'L3n':
                            regrow_pos_flatten = self._get_L3n_regrow_pos_optimized(self.hidden_mask)
                        
                        case _:
                            raise NotImplementedError
                
            regrow_pos = regrow_pos_flatten.view(regrow_pos.shape)

        elif self.shared_mask_zone:  # shared_mask_sw = False, shared_mask_zone = True
            # weight: [c_out, c_in * K * K]
            # mask: [c_out, num_sw, zone_sz * K * K]
            match self.regrow_method:
                case 'rand':
                    rand_values = torch.rand_like(self.mask, dtype=torch.float)
                    rand_values.masked_fill_(self.mask, 0.)  # Omit active positions
                    _, indices = torch.topk(rand_values, k=self.num_update, dim=-1)
                    regrow_pos.scatter_(2, indices, True)
                
                case 'L3n':
                    regrow_pos = self._get_L3n_regrow_pos_optimized(self.mask)

                case _:
                    raise NotImplementedError

        else:  # shared_mask_sw = False, shared_mask_zone = False
            # TODO: introduce zone after optimized
            # weight: [c_out, c_in * K * K]
            # mask: [c_out, num_sw, c_in * K * K]
            match self.regrow_method:
                case 'rand':
                    rand_values = torch.rand_like(self.mask, dtype=torch.float)
                    rand_values.masked_fill_(self.mask, 0.)  # Omit active positions
                    _, indices = torch.topk(rand_values, k=self.num_update, dim=-1)
                    regrow_pos.scatter_(2, indices, True)

                case 'L3n':  # TODO: introduce zone
                    regrow_pos = self._get_L3n_regrow_pos_optimized(self.mask)

                case _:
                    raise NotImplementedError

        # Update mask
        if not self.use_hidden:
            self.mask.masked_fill_(regrow_pos, True)
            self._check_mask(self.mask, self.num_zeros)
        else:
            self.hidden_mask.masked_fill_(regrow_pos, True)
            self._check_mask(self.hidden_mask, self.num_zeros)
            self.mask = self._hidden_to_mask()
            self._check_mask(self.mask, self.num_zeros)

        return regrow_pos
    

    # def _get_L3n_regrow_pos(self, mask_included):
    #     # mask_included: [c_out, num_sw, c_in_i * K * K]
    #     regrow_pos = torch.zeros_like(mask_included, dtype=torch.bool)

    #     # Get the original indices on the input for each sliding window
    #     ph, pw = self.in_h + 2 * self.padding, self.in_w + 2 * self.padding
    #     sw_indices = torch.arange(self.zone_sz * ph * pw).float()
    #     sw_indices = sw_indices.view(self.zone_sz, ph, pw)
    #     sw_indices = F.unfold(sw_indices, kernel_size=self.kernel_size, stride=self.stride)  # [c_in_i * K * K, num_sw]
    #     sw_indices.t_()  # [num_sw, c_in_i * K * K]
    #     sw_indices = sw_indices.long()

    #     # [c_out, num_sw, c_in_i * K * K]
    #     # Record scores on all kernels in this tensor
    #     scores_all_k = torch.empty_like(mask_included, dtype=torch.float)

    #     for c_out_idx in range(self.c_out):  # for each output channel
    #         # Build adjacency matrix
    #         am = torch.zeros(self.num_sw, self.zone_sz * ph * pw, dtype=torch.bool).to(self.weight.device)
    #         j_indices = torch.arange(self.num_sw).unsqueeze(1)  # [num_sw, 1]
    #         am[j_indices, sw_indices] = mask_included[c_out_idx]

    #         # [num_sw, c_in_i * pH * pW], where pH and pW means padded H and W
    #         # In am, each row corresponds to a sliding window,
    #         # and each column corresponds to a position in the input feature map.
    #         am = am.float()

    #         DTPATHS1 = am
    #         TDPATHS1 = DTPATHS1.transpose(1, 0)
    #         DDPATHS2 = torch.matmul(DTPATHS1, TDPATHS1)
    #         TTPATHS2 = torch.matmul(TDPATHS1, DTPATHS1)

    #         BDDPATHS2 = DDPATHS2 != 0
    #         BTTPATHS2 = TTPATHS2 != 0
    #         elcl_DT = (torch.sum(DTPATHS1, dim=1) - DDPATHS2) * BDDPATHS2
    #         elcl_TD = (torch.sum(TDPATHS1, dim=1) - TTPATHS2) * BTTPATHS2

    #         elcl_DT[elcl_DT == 0] = 1
    #         elcl_TD[elcl_TD == 0] = 1
    #         elcl_DT -= 1
    #         elcl_TD -= 1
    #         # CH3 branch in original code
    #         if self.ch_method == 'CH3':
    #             elcl_DT = 1 / (elcl_DT + 1) * BDDPATHS2  # [c_out, num_sw, num_sw]
    #             elcl_TD = 1 / (elcl_TD + 1) * BTTPATHS2  # [c_out, zone_sz * ph * pw, zone_sz * ph * pw]
    #         elif self.ch_method == 'CH2':
    #             elcl_DT = 1 / (elcl_DT + 1) * (DDPATHS2 + BDDPATHS2)
    #             elcl_TD = 1 / (elcl_TD + 1) * (TTPATHS2 + BTTPATHS2)
    #         elif self.ch_method == 'CH3.1':
    #             elcl_DT = 1 / ((elcl_DT + 1) ** (1 + elcl_DT / (elcl_DT + 1))) * (DDPATHS2 + BDDPATHS2)
    #             elcl_TD = 1 / ((elcl_TD + 1) ** (1 + elcl_TD / (elcl_TD + 1))) * (TTPATHS2 + BTTPATHS2)
    #         else:
    #             raise NotImplementedError

    #         elcl_DT = torch.matmul(elcl_DT, DTPATHS1)
    #         elcl_TD = torch.matmul(elcl_TD, TDPATHS1)

    #         scores_one_k = elcl_DT + elcl_TD.T  # [num_sw, c_in * pH * pW]

            
    #         # Map the scores to the positions on the mask, via indices on the reverse direction
    #         scores_one_k = scores_one_k[j_indices, sw_indices]  # [num_sw, K * K]
    #         scores_all_k[c_out_idx] = scores_one_k
        
    #     if not self.avg_regrow:
    #         scores = scores_all_k  # [c_out, num_sw, c_in_i * K * K]
    #     else:
    #         if self.use_opt4:
    #             # Option 4: replace the scores on the active positions with the highest scores of the inactive positions
    #             highest_score = torch.max(scores_all_k[~mask_included])
    #             scores_all_k[mask_included] = highest_score
    #             scores = scores_all_k.mean(dim=1, keepdim=True)
    #         else:
    #             # Original approach: average scores of inactive positions
    #             zeros_in_each_col = (~mask_included).sum(dim=1)  # Each col corresponds to a neuron
    #             avg_scores = scores_all_k.masked_fill(mask_included, 0)
    #             avg_scores = avg_scores.sum(dim=1) / (zeros_in_each_col + 1e-6)
    #             scores = avg_scores.unsqueeze(1)  # [c_out, 1, c_in_i * K * K]

    #     if not self.soft:
    #         scores = scores.masked_fill(mask_included, -1)  # Omit active positions
    #         _, top_indices = torch.topk(scores, k=self.num_update, dim=-1)
    #         regrow_pos.scatter_(2, top_indices, True)
    #     else:
    #         exp = self.delta / (1 - self.delta)
    #         scores **= exp

    #         scores += 1e-6  # To avoid the case that all inactive scores are 0
    #         scores = scores.masked_fill(mask_included, 0.)
    #         scores = scores.flatten(0, -2)  # [c_out * num_sw, K * K]
    #         indices = torch.multinomial(scores, self.num_update, replacement=False)
    #         indices = indices.view(self.c_out, self.num_sw, -1)
    #         regrow_pos.scatter_(2, indices, True)

    #     return regrow_pos


    def _get_L3n_regrow_pos_optimized(self, mask_included):
        # mask_included: [l3n_bs, num_sw, zone_sz * K * K]
        regrow_pos = torch.zeros_like(mask_included, dtype=torch.bool)

        # Get the original indices on the input for each sliding window
        ph, pw = self.in_h + 2 * self.padding, self.in_w + 2 * self.padding
        sw_indices = torch.arange(self.zone_sz * ph * pw, device=mask_included.device).float()
        sw_indices = sw_indices.view(self.zone_sz, ph, pw)
        sw_indices = F.unfold(sw_indices, kernel_size=self.kernel_size, stride=self.stride)  # [zone_sz * K * K, num_sw]
        sw_indices = sw_indices.t().long()  # [num_sw, zone_sz * K * K]

        # Build adjacency matrices for all output channels at once
        # [l3n_bs, num_sw, zone_sz * ph * pw]
        am = torch.zeros(self.l3n_batch_sz, self.num_sw, self.zone_sz * ph * pw, 
                        dtype=torch.bool, device=self.weight.device)
        
        # Expand indices for broadcasting
        c_out_indices = torch.arange(self.l3n_batch_sz, device=mask_included.device).unsqueeze(1).unsqueeze(2)  # [c_out, 1, 1]
        j_indices = torch.arange(self.num_sw, device=mask_included.device).unsqueeze(0).unsqueeze(2)     # [1, num_sw, 1]
        sw_indices_expanded = sw_indices.unsqueeze(0)  # [1, num_sw, zone_sz * K * K]
        
        am[c_out_indices, j_indices, sw_indices_expanded] = mask_included
        
        # Vectorized computation for all channels
        DTPATHS1 = am.half()  # [l3n_bs, num_sw, zone_sz * ph * pw]
        TDPATHS1 = DTPATHS1.transpose(-2, -1)  # [l3n_bs, zone_sz * ph * pw, num_sw]
        DDPATHS2 = torch.bmm(DTPATHS1, TDPATHS1)  # [l3n_bs, num_sw, num_sw]
        TTPATHS2 = torch.bmm(TDPATHS1, DTPATHS1)  # [l3n_bs, zone_sz * ph * pw, zone_sz * ph * pw]
        BDDPATHS2 = (DDPATHS2 != 0)  # [l3n_bs, num_sw, num_sw]
        BTTPATHS2 = (TTPATHS2 != 0)  # [l3n_bs, zone_sz * ph * pw, zone_sz * ph * pw]
        
        sum_DTPATHS1 = torch.sum(DTPATHS1, dim=-1)  # [l3n_bs, num_sw]
        sum_TDPATHS1 = torch.sum(TDPATHS1, dim=-1)  # [l3n_bs, zone_sz * ph * pw]
        sum_DTPATHS1 = sum_DTPATHS1.unsqueeze(1)  # [l3n_bs, 1, num_sw]
        sum_TDPATHS1 = sum_TDPATHS1.unsqueeze(1)  # [l3n_bs, 1, zone_sz * ph * pw]

        # elcl_DT = (sum_DTPATHS1 - DDPATHS2) * BDDPATHS2  # [l3n_bs, num_sw, num_sw]
        # elcl_TD = (sum_TDPATHS1 - TTPATHS2) * BTTPATHS2  # [l3n_bs, zone_sz * ph * pw, zone_sz * ph * pw]

        elcl_DT = sum_DTPATHS1 - DDPATHS2  # [l3n_bs, num_sw, num_sw]
        elcl_DT *= BDDPATHS2  # [l3n_bs, num_sw, num_sw]
        elcl_TD = sum_TDPATHS1 - TTPATHS2  # [l3n_bs, zone_sz * ph * pw, zone_sz * ph * pw]
        elcl_TD *= BTTPATHS2  # [l3n_bs, zone_sz * ph * pw, zone_sz * ph * pw]

        # elcl_DT = torch.where(elcl_DT == 0, torch.ones_like(elcl_DT), elcl_DT)
        # elcl_TD = torch.where(elcl_TD == 0, torch.ones_like(elcl_TD), elcl_TD)
        elcl_DT[elcl_DT == 0] = 1.
        elcl_TD[elcl_TD == 0] = 1.
        elcl_DT -= 1
        elcl_TD -= 1
        if self.ch_method == 'CH3':
            # elcl_DT = 1 / (elcl_DT + 1) * BDDPATHS2  # [l3n_bs, num_sw, num_sw]
            # elcl_TD = 1 / (elcl_TD + 1) * BTTPATHS2  # [l3n_bs, zone_sz * ph * pw, zone_sz * ph * pw]
            elcl_DT += 1
            elcl_DT.reciprocal_()
            elcl_DT *= BDDPATHS2  # [l3n_bs, num_sw, num_sw]

            elcl_TD += 1
            elcl_TD.reciprocal_()
            elcl_TD *= BTTPATHS2  # [l3n_bs, zone_sz * ph * pw, zone_sz * ph * pw]

        elif self.ch_method == 'CH2':
            # elcl_DT = 1 / (elcl_DT + 1) * (DDPATHS2 + BDDPATHS2)
            # elcl_TD = 1 / (elcl_TD + 1) * (TTPATHS2 + BTTPATHS2)
            elcl_DT += 1
            elcl_DT.reciprocal_()
            elcl_DT *= (DDPATHS2 + BDDPATHS2)

            elcl_TD += 1
            elcl_TD.reciprocal_()
            elcl_TD *= (TTPATHS2 + BTTPATHS2)

        elif self.ch_method == 'CH3.1':
            # elcl_DT = 1 / ((elcl_DT + 1) ** (1 + elcl_DT / (elcl_DT + 1))) * (DDPATHS2 + BDDPATHS2)
            # elcl_TD = 1 / ((elcl_TD + 1) ** (1 + elcl_TD / (elcl_TD + 1))) * (TTPATHS2 + BTTPATHS2)
            elcl_DT.add_(1)
            elcl_DT.pow_(2 - 1 / elcl_DT)
            # elcl_DT_r = elcl_DT.reciprocal()
            # elcl_DT_r.neg_()
            # elcl_DT_r += 2
            # elcl_DT.pow_(elcl_DT_r)
            elcl_DT.reciprocal_()
            elcl_DT *= (DDPATHS2 + BDDPATHS2)

            elcl_TD.add_(1)
            elcl_TD.pow_(2 - 1 / elcl_TD)
            elcl_TD.reciprocal_()
            elcl_TD *= (TTPATHS2 + BTTPATHS2)

        else:
            raise NotImplementedError
        
        elcl_DT_result = torch.bmm(elcl_DT, DTPATHS1)  # [l3n_bs, num_sw, zone_sz * ph * pw]
        elcl_TD_result = torch.bmm(elcl_TD, TDPATHS1)  # [l3n_bs, zone_sz * ph * pw, num_sw]
        scores_one_k = elcl_DT_result + elcl_TD_result.transpose(-2, -1)  # [l3n_bs, num_sw, zone_sz * ph * pw]

        # Map the scores to the positions on the mask
        # c_out_indices = torch.arange(self.c_out, device=mask_included.device).unsqueeze(1).unsqueeze(2)  # [c_out, 1, 1]
        # j_indices = torch.arange(self.num_sw, device=mask_included.device).unsqueeze(0).unsqueeze(2)     # [1, num_sw, 1]
        scores_all_k = scores_one_k[c_out_indices, j_indices, sw_indices]  # [l3n_bs, num_sw, zone_sz * K * K]
    
        if not self.avg_regrow:
            scores = scores_all_k  # [l3n_bs, num_sw, c_in_i * K * K]
        else:
            if self.use_opt4:
                # Option 4: replace the scores on the active positions with the highest scores of the inactive positions
                highest_score = torch.max(scores_all_k[~mask_included])
                scores_all_k[mask_included] = highest_score
                scores = scores_all_k.mean(dim=1, keepdim=True)
            else:
                # Original approach: average scores of inactive positions
                zeros_in_each_col = (~mask_included).sum(dim=1)  # Each col corresponds to a neuron
                avg_scores = scores_all_k.masked_fill(mask_included, 0)
                avg_scores = avg_scores.sum(dim=1) / (zeros_in_each_col + 1e-6)
                scores = avg_scores.unsqueeze(1)  # [l3n_bs, 1, c_in_i * K * K]

        if not self.soft:
            scores = scores.masked_fill(mask_included, -1)  # Omit active positions
            _, top_indices = torch.topk(scores, k=self.num_update, dim=-1)
            regrow_pos.scatter_(2, top_indices, True)
        else:
            exp = self.delta / (1 - self.delta)
            scores **= exp

            scores += 1e-6  # To avoid the case that all inactive scores are 0
            scores = scores.masked_fill(mask_included, 0.)
            scores = scores.flatten(0, -2)  # [l3n_bs * num_sw, K * K]
            indices = torch.multinomial(scores, self.num_update, replacement=False)
            indices = indices.view(self.l3n_batch_sz, self.num_sw, -1)
            regrow_pos.scatter_(2, indices, True)

        return regrow_pos


    def _check_mask(self, mask, num_zeros):
        # For shared_mask_sw = True: [c_out, zone_sz * K * K]
        # For shared_mask_sw = False: [c_out * num_sw, zone_sz * K * K]
        mask = mask.view(-1, self.zone_sz * self.kernel_size * self.kernel_size)
        # Count number of zeros in each zone
        zeros_count = (mask == False).sum(dim=-1)
        
        # Check if number of zeros equals num_zeros for each vector
        assert torch.all(zeros_count == num_zeros), f'Number of zeros should be {num_zeros}, but got {zeros_count}'


    def calculate_mask_convergence(self):
        """
        Calculate mask convergence scores using min and mean methods.
        
        Returns:
            tuple: (min_score, mean_score) - convergence scores for this layer
        """
        if self.shared_mask_sw:
            if not self.use_hidden:
                return None, None
            mask = self.hidden_mask.flatten(2)  # [c_out, num_sw, zone_sz * K * K]
        else:
            mask = self.mask
        
        score = mask.float().mean(dim=1)  # Mean across sliding windows
        scores = score.topk(self.num_active)[0]
        
        min_score = scores.min().item()
        mean_score = scores.mean().item()
        
        return min_score, mean_score


    def _hidden_to_mask(self):
        assert self.shared_mask_sw and self.use_hidden
        hidden_avg = self.hidden_mask.flatten(2)  # [c_out, num_sw, zone_sz * K * K]
        hidden_avg = hidden_avg.float().mean(dim=1)  # [c_out, zone_sz * K * K]
        _, indices = torch.topk(hidden_avg, self.num_active, dim=1)
        mask = torch.zeros(
            self.c_out, self.zone_sz * self.kernel_size * self.kernel_size,
            dtype=torch.bool, device=self.hidden_mask.device)
        mask.scatter_(1, indices, True)
        self._check_mask(mask, self.num_zeros)
        return mask
    

    def _unfold_input(self, x):
        # [batch_sz, c_in*kernel_sz*kernel_sz, num_sw]
        return F.unfold(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
        

