import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.utils import utils_l1_norm, utils_l2_norm, calculate_network_dims, new_utils_l2_norm
import math
# from hat.modules import HATLinear, HATConv2d, TaskIndexedLayerNorm
# from hat import HATPayload


class CReLU(nn.Module):
    """Concatenated ReLU: ``relu(cat(x, -x))``.

    The concatenation doubles the size of one axis: the last axis for 2-D
    inputs, and axis 1 (the channel axis under an NCHW layout) for 4-D inputs.
    """

    def __init__(self, inplace=False):
        # ``inplace`` is accepted for signature parity with nn.ReLU; it is unused.
        super(CReLU, self).__init__()

    def forward(self, x):
        """Apply CReLU to a 2-D or 4-D tensor.

        Raises:
            ValueError: if ``x`` is neither 2-D nor 4-D.
        """
        if x.dim() == 2:
            x = torch.cat((x, -x), -1)
        elif x.dim() == 4:
            x = torch.cat((x, -x), 1)
        else:
            # Previously ``raise f"..."`` with a typo'd attribute: that raised
            # AttributeError/TypeError instead of a meaningful error.
            raise ValueError(f"{x.shape} is invalid in CReLU")
        return F.relu(x)

class DeepFourier(nn.Module):
    """Fourier-feature activation: ``cat(cos(x), sin(x))``.

    The concatenation doubles the size of one axis: the last axis for 2-D
    inputs, and axis 1 (the channel axis under an NCHW layout) for 4-D inputs.
    """

    def __init__(self):
        super(DeepFourier, self).__init__()

    def forward(self, x):
        """Apply the cos/sin feature map to a 2-D or 4-D tensor.

        Raises:
            ValueError: if ``x`` is neither 2-D nor 4-D.
        """
        if x.dim() == 2:
            x = torch.cat((torch.cos(x), torch.sin(x)), -1)
        elif x.dim() == 4:
            x = torch.cat((torch.cos(x), torch.sin(x)), 1)
        else:
            # Previously ``raise f"..."`` with a typo'd attribute, which never
            # produced the intended message.
            raise ValueError(f"{x.shape} is invalid in DeepFourier")
        return x


class MixNormalResNet(nn.Module):
    """ResNet-18-style network with a selectable activation and per-layer tracking.

    Every convolution is an explicitly named attribute (``conv<stage>_<block>_<i>``)
    so that ``forward`` can record per-layer outputs in ``self.activations`` and
    ``self.activations_for_redo``.

    When ``activation`` is ``'crelu'`` or ``'deepfourier'`` the network is built
    in "double" mode: ``conv1`` and the first conv of every residual block emit
    half the usual channels, and the activation (which concatenates two views
    of its input along the channel axis) restores the usual width. The second
    activation of each residual block is then a plain ReLU.

    Args:
        input_shape: input shape; ``input_shape[-1]`` selects the stem and must
            be 32, 84 or 224. ``conv1`` expects 3 input channels.
        num_classes: output size of ``fc1``.
        activation: one of 'relu', 'crelu', 'deepfourier', 'prelu'.
        load_pretrained: copy torchvision's ImageNet ResNet-18 weights into
            blocks 2-5 (ignored in double mode).
        dropout_percentage: dropout probability after each residual branch.
        disable_bn: replace (most) BatchNorm layers by ``nn.Identity``.
        agent_type: stored as-is; not used inside this class.
    """

    def __init__(self,
                 input_shape=(3, 32, 32),
                 num_classes=10,
                 activation='relu',
                 load_pretrained=False,
                 dropout_percentage=0.0,
                 disable_bn=False,
                 agent_type=None):
        super().__init__()
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.activation = activation
        self.load_pretrained = load_pretrained
        self.dropout_percentage = dropout_percentage
        self.disable_bn = disable_bn
        self.agent_type = agent_type

        # CReLU / DeepFourier double their input's channel count, so the convs
        # feeding them are built at half width.
        double = activation in ('crelu', 'deepfourier')

        self.build_network(double=double)

        # The pretrained weights only fit the standard-width layout, so report
        # what actually happens rather than what was requested.
        pretrained_loaded = bool(load_pretrained) and not double
        print(f'Pretrained model is loaded: {pretrained_loaded}')
        if pretrained_loaded:
            self._load_pretrained_resnet18()

    def build_act(self, size, activation=None):
        """Return a new activation module.

        Args:
            size: channel count; only used by 'prelu' (one slope per channel).
            activation: activation name; defaults to ``self.activation``.

        Raises:
            ValueError: if the activation name is not supported.
        """
        if activation is None:
            activation = self.activation
        if activation == 'relu':
            return nn.ReLU()
        if activation == 'crelu':
            return CReLU()
        if activation == 'deepfourier':
            return DeepFourier()
        if activation == 'prelu':
            return nn.PReLU(size)
        # Previously returned None, which only failed later inside forward().
        raise ValueError(f"Unsupported activation: {activation!r}")

    def _build_bn(self, num_features):
        """Return BatchNorm2d(num_features), or Identity when BN is disabled."""
        if self.disable_bn:
            return nn.Identity()
        return nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True)

    def _register_conv(self, name, conv):
        """Track a conv in ``self.layers`` (followed by a None slot) and ``self.layer_names``."""
        self.layers.append(conv)
        self.layer_names.append(name)
        self.layers.append(None)

    def build_network(self, double=False):
        """Create every layer plus the per-layer bookkeeping tables.

        Module creation order is significant: it fixes ``parameters()`` /
        ``state_dict()`` ordering (and therefore optimizer state and checkpoint
        compatibility), so keep it stable.

        Args:
            double: build the half-width variant used with channel-doubling
                activations (see the class docstring).

        Raises:
            ValueError: if ``input_shape[-1]`` is not 32, 84 or 224.
        """
        self.layers = []
        self.layer_names = []
        self.last_filter_output = 512

        def half(width):
            # Output width of a conv that feeds a (possibly channel-doubling) activation.
            return width // 2 if double else width

        def conv3x3(in_ch, out_ch, stride=1):
            return nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=(3, 3),
                             stride=(stride, stride), padding=(1, 1), bias=True)

        def shortcut1x1(in_ch, out_ch):
            # Projection shortcut for the stride-2 residual blocks.
            return nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=(1, 1),
                             stride=(2, 2), padding=(0, 0), bias=True)

        # In double mode the post-residual activation stays a plain ReLU so the
        # channel count is not doubled again.
        second_act = 'relu' if double else None

        # BLOCK-1 (stem): depends on the input resolution.
        if self.input_shape[-1] == 32:
            # Small images (e.g. CIFAR): 3x3 stride-1 conv, no BN, then max-pool.
            self.conv1 = torch.nn.Conv2d(3, half(64), kernel_size=3, stride=1, padding=1)
            self.batchnorm1 = torch.nn.Identity()
            self.maxpool1 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        elif self.input_shape[-1] == 224:
            # Standard ImageNet stem: 7x7 stride-2 conv, BN, 3x3 stride-2 max-pool.
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=half(64), kernel_size=(7, 7),
                                   stride=(2, 2), padding=(3, 3), bias=True)
            self.batchnorm1 = self._build_bn(half(64))
            self.maxpool1 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        elif self.input_shape[-1] == 84:
            self.conv1 = nn.Conv2d(3, half(64), kernel_size=7, stride=2, padding=3, bias=True)
            # NOTE(review): unlike the 224 branch this BN ignores disable_bn; kept
            # as-is so existing checkpoints keep the same state_dict keys.
            self.batchnorm1 = nn.BatchNorm2d(half(64), eps=1e-05, momentum=0.1, affine=True)
            self.maxpool1 = nn.Identity()
        else:
            raise ValueError(f"{self.input_shape[-1]} is not supported in ResNet")

        self._register_conv('conv1', self.conv1)
        self.act1 = self.build_act(half(64))

        # BLOCK-2 (1): 64 channels, stride 1, identity shortcut.
        self.conv2_1_1 = conv3x3(64, half(64))
        self.act2_1_1 = self.build_act(half(64))
        self._register_conv('conv2_1_1', self.conv2_1_1)
        self.batchnorm2_1_1 = self._build_bn(half(64))
        self.conv2_1_2 = conv3x3(64, 64)
        self.act2_1_2 = self.build_act(64, second_act)
        self._register_conv('conv2_1_2', self.conv2_1_2)
        self.batchnorm2_1_2 = self._build_bn(64)
        self.dropout2_1 = nn.Dropout(p=self.dropout_percentage)
        # BLOCK-2 (2)
        self.conv2_2_1 = conv3x3(64, half(64))
        self.act2_2_1 = self.build_act(half(64))
        self._register_conv('conv2_2_1', self.conv2_2_1)
        self.batchnorm2_2_1 = self._build_bn(half(64))
        self.conv2_2_2 = conv3x3(64, 64)
        self.act2_2_2 = self.build_act(64, second_act)
        self._register_conv('conv2_2_2', self.conv2_2_2)
        self.batchnorm2_2_2 = self._build_bn(64)
        self.dropout2_2 = nn.Dropout(p=self.dropout_percentage)

        # BLOCK-3 (1): 128 channels, stride 2, projection shortcut.
        self.conv3_1_1 = conv3x3(64, half(128), stride=2)
        self.act3_1_1 = self.build_act(half(128))
        self._register_conv('conv3_1_1', self.conv3_1_1)
        self.batchnorm3_1_1 = self._build_bn(half(128))
        self.conv3_1_2 = conv3x3(128, 128)
        self.act3_1_2 = self.build_act(128, second_act)
        self._register_conv('conv3_1_2', self.conv3_1_2)
        self.batchnorm3_1_2 = self._build_bn(128)
        self.conv_concat_adjust_3 = shortcut1x1(64, 128)
        # Shortcut convs are named but intentionally not added to self.layers.
        self.layer_names.append('conv_concat_adjust_3')
        self.batchnorm_adjust_3 = self._build_bn(128)
        self.dropout3_1 = nn.Dropout(p=self.dropout_percentage)
        # BLOCK-3 (2)
        self.conv3_2_1 = conv3x3(128, half(128))
        self.act3_2_1 = self.build_act(half(128))
        self._register_conv('conv3_2_1', self.conv3_2_1)
        self.batchnorm3_2_1 = self._build_bn(half(128))
        self.conv3_2_2 = conv3x3(128, 128)
        self.act3_2_2 = self.build_act(128, second_act)
        self._register_conv('conv3_2_2', self.conv3_2_2)
        self.batchnorm3_2_2 = self._build_bn(128)
        self.dropout3_2 = nn.Dropout(p=self.dropout_percentage)

        # BLOCK-4 (1): 256 channels, stride 2, projection shortcut.
        self.conv4_1_1 = conv3x3(128, half(256), stride=2)
        self.act4_1_1 = self.build_act(half(256))
        self._register_conv('conv4_1_1', self.conv4_1_1)
        self.batchnorm4_1_1 = self._build_bn(half(256))
        self.conv4_1_2 = conv3x3(256, 256)
        self.act4_1_2 = self.build_act(256, second_act)
        self._register_conv('conv4_1_2', self.conv4_1_2)
        self.batchnorm4_1_2 = self._build_bn(256)
        self.conv_concat_adjust_4 = shortcut1x1(128, 256)
        self.layer_names.append('conv_concat_adjust_4')
        self.batchnorm_adjust_4 = self._build_bn(256)
        self.dropout4_1 = nn.Dropout(p=self.dropout_percentage)
        # BLOCK-4 (2)
        self.conv4_2_1 = conv3x3(256, half(256))
        self.act4_2_1 = self.build_act(half(256))
        self._register_conv('conv4_2_1', self.conv4_2_1)
        self.batchnorm4_2_1 = self._build_bn(half(256))
        self.conv4_2_2 = conv3x3(256, 256)
        self.act4_2_2 = self.build_act(256, second_act)
        self._register_conv('conv4_2_2', self.conv4_2_2)
        self.batchnorm4_2_2 = self._build_bn(256)
        self.dropout4_2 = nn.Dropout(p=self.dropout_percentage)

        # BLOCK-5 (1): 512 channels, stride 2, projection shortcut.
        self.conv5_1_1 = conv3x3(256, half(512), stride=2)
        self.act5_1_1 = self.build_act(half(512))
        self._register_conv('conv5_1_1', self.conv5_1_1)
        self.batchnorm5_1_1 = self._build_bn(half(512))
        self.conv5_1_2 = conv3x3(512, 512)
        self.act5_1_2 = self.build_act(512, second_act)
        self._register_conv('conv5_1_2', self.conv5_1_2)
        self.batchnorm5_1_2 = self._build_bn(512)
        self.conv_concat_adjust_5 = shortcut1x1(256, 512)
        self.layer_names.append('conv_concat_adjust_5')
        self.batchnorm_adjust_5 = self._build_bn(512)
        self.dropout5_1 = nn.Dropout(p=self.dropout_percentage)
        # BLOCK-5 (2)
        self.conv5_2_1 = conv3x3(512, half(512))
        self.act5_2_1 = self.build_act(half(512))
        self._register_conv('conv5_2_1', self.conv5_2_1)
        self.batchnorm5_2_1 = self._build_bn(half(512))
        self.conv5_2_2 = conv3x3(512, 512)
        self.act5_2_2 = self.build_act(512, second_act)
        self._register_conv('conv5_2_2', self.conv5_2_2)
        self.batchnorm5_2_2 = self._build_bn(512)
        self.dropout5_2 = nn.Dropout(p=self.dropout_percentage)

        # Classifier: global average pool + linear head.
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc1 = nn.Linear(in_features=512, out_features=self.num_classes)
        self.layers.append(self.fc1)
        self.layer_names.append('fc1')

        # Recorded width of each tracked layer. These are the full widths even in
        # double mode, because the recorded tensors are taken after the
        # channel-doubling activation (or from the full-width second conv).
        layer_widths = [
            ("conv1", 64),
            ("conv2_1_1", 64), ("conv2_1_2", 64),
            ("conv2_2_1", 64), ("conv2_2_2", 64),
            ("conv3_1_1", 128), ("conv3_1_2", 128), ("conv_concat_adjust_3", 128),
            ("conv3_2_1", 128), ("conv3_2_2", 128),
            ("conv4_1_1", 256), ("conv4_1_2", 256), ("conv_concat_adjust_4", 256),
            ("conv4_2_1", 256), ("conv4_2_2", 256),
            ("conv5_1_1", 512), ("conv5_1_2", 512), ("conv_concat_adjust_5", 512),
            ("conv5_2_1", 512), ("conv5_2_2", 512),
            ("fc1", self.num_classes),
        ]

        # Contiguous [start, end) slice of each layer in a flat node index.
        self.name_to_start_end = {}
        offset = 0
        for name, width in layer_widths:
            self.name_to_start_end[name] = (offset, offset + width)
            offset += width

        self.layer_channels = {
            name: end - start
            for name, (start, end) in self.name_to_start_end.items()
        }

        self.name_layers = list(self.name_to_start_end.keys())

        self.total_nodes = offset  # == 4800 + num_classes

    def _load_pretrained_resnet18(self):
        """Copy torchvision's ImageNet ResNet-18 weights into blocks 2-5.

        The stem (BLOCK-1) is not copied. BatchNorm parameters and running
        statistics are copied only when BN is enabled. ``fc1`` is copied only
        when its shape matches torchvision's classifier; otherwise it keeps its
        fresh initialization. Only called from ``__init__`` in non-double mode.
        """
        import torchvision.models as models

        try:
            # Newer torchvision: weights enum API.
            resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        except AttributeError:
            # Older torchvision has no ResNet18_Weights; use the legacy flag.
            resnet = models.resnet18(pretrained=True)

        # Make sure we're on the same device as this model.
        device = next(self.parameters()).device
        resnet = resnet.to(device)

        # --------- Convolution layers (skip BLOCK-1) ---------
        conv_pairs = [
            # BLOCK-2  (ResNet layer1)
            (resnet.layer1[0].conv1, self.conv2_1_1),
            (resnet.layer1[0].conv2, self.conv2_1_2),
            (resnet.layer1[1].conv1, self.conv2_2_1),
            (resnet.layer1[1].conv2, self.conv2_2_2),

            # BLOCK-3  (ResNet layer2)
            (resnet.layer2[0].conv1,         self.conv3_1_1),
            (resnet.layer2[0].conv2,         self.conv3_1_2),
            (resnet.layer2[0].downsample[0], self.conv_concat_adjust_3),
            (resnet.layer2[1].conv1,         self.conv3_2_1),
            (resnet.layer2[1].conv2,         self.conv3_2_2),

            # BLOCK-4  (ResNet layer3)
            (resnet.layer3[0].conv1,         self.conv4_1_1),
            (resnet.layer3[0].conv2,         self.conv4_1_2),
            (resnet.layer3[0].downsample[0], self.conv_concat_adjust_4),
            (resnet.layer3[1].conv1,         self.conv4_2_1),
            (resnet.layer3[1].conv2,         self.conv4_2_2),

            # BLOCK-5  (ResNet layer4)
            (resnet.layer4[0].conv1,         self.conv5_1_1),
            (resnet.layer4[0].conv2,         self.conv5_1_2),
            (resnet.layer4[0].downsample[0], self.conv_concat_adjust_5),
            (resnet.layer4[1].conv1,         self.conv5_2_1),
            (resnet.layer4[1].conv2,         self.conv5_2_2),
        ]

        for src, dst in conv_pairs:
            dst.weight.data.copy_(src.weight.data)
            if dst.bias is not None:
                if src.bias is not None:
                    dst.bias.data.copy_(src.bias.data)
                else:
                    # Source conv has no bias: zero ours so the loaded layer
                    # computes the same function as the pretrained one instead
                    # of adding a random offset.
                    dst.bias.data.zero_()

        if not self.disable_bn:
            # --------- BatchNorm layers ---------
            bn_pairs = [
                # BLOCK-2 (layer1)
                (resnet.layer1[0].bn1,           self.batchnorm2_1_1),
                (resnet.layer1[0].bn2,           self.batchnorm2_1_2),
                (resnet.layer1[1].bn1,           self.batchnorm2_2_1),
                (resnet.layer1[1].bn2,           self.batchnorm2_2_2),

                # BLOCK-3 (layer2)
                (resnet.layer2[0].bn1,           self.batchnorm3_1_1),
                (resnet.layer2[0].bn2,           self.batchnorm3_1_2),
                (resnet.layer2[0].downsample[1], self.batchnorm_adjust_3),
                (resnet.layer2[1].bn1,           self.batchnorm3_2_1),
                (resnet.layer2[1].bn2,           self.batchnorm3_2_2),

                # BLOCK-4 (layer3)
                (resnet.layer3[0].bn1,           self.batchnorm4_1_1),
                (resnet.layer3[0].bn2,           self.batchnorm4_1_2),
                (resnet.layer3[0].downsample[1], self.batchnorm_adjust_4),
                (resnet.layer3[1].bn1,           self.batchnorm4_2_1),
                (resnet.layer3[1].bn2,           self.batchnorm4_2_2),

                # BLOCK-5 (layer4)
                (resnet.layer4[0].bn1,           self.batchnorm5_1_1),
                (resnet.layer4[0].bn2,           self.batchnorm5_1_2),
                (resnet.layer4[0].downsample[1], self.batchnorm_adjust_5),
                (resnet.layer4[1].bn1,           self.batchnorm5_2_1),
                (resnet.layer4[1].bn2,           self.batchnorm5_2_2),
            ]

            # Must stay inside this branch: bn_pairs only exists when BN is
            # enabled (previously this loop raised NameError with disable_bn).
            for src, dst in bn_pairs:
                dst.weight.data.copy_(src.weight.data)
                dst.bias.data.copy_(src.bias.data)
                dst.running_mean.data.copy_(src.running_mean.data)
                dst.running_var.data.copy_(src.running_var.data)
                dst.num_batches_tracked.data.copy_(src.num_batches_tracked.data)

        # AdaptiveAvgPool2d has no learnable parameters, so nothing to copy.

        # --------- Fully connected layer ---------
        # Only load when shapes match (e.g. num_classes == 1000); otherwise fc1
        # keeps its fresh initialization and is trained from scratch.
        if (self.fc1.weight.shape == resnet.fc.weight.shape and
                self.fc1.bias.shape == resnet.fc.bias.shape):
            self.fc1.weight.data.copy_(resnet.fc.weight.data)
            self.fc1.bias.data.copy_(resnet.fc.bias.data)

    def forward(self, x):
        """Run the network and record per-layer outputs.

        Side effects: replaces ``self.activations`` (name -> detached tensor)
        and ``self.activations_for_redo`` (name -> ``(tensor, 'conv', tag)``
        with the non-detached tensor; ``tag`` is 'conv' or 'fc' as set below).

        For each residual block, the ``*_1`` entry is recorded after its
        activation. The ``*_2`` entry is recorded after BN but before dropout,
        the residual add and the activation. Shortcut projections
        (``conv_concat_adjust_*``) are recorded only in
        ``activations_for_redo``.

        Returns:
            Raw ``fc1`` outputs of shape ``(batch, num_classes)``.
        """
        self.activations = {}
        self.activations_for_redo = {}

        # block 1 --> stem
        x = self.act1(self.batchnorm1(self.conv1(x)))
        self.activations['conv1'] = x.detach()
        self.activations_for_redo['conv1'] = (x, 'conv', 'conv')
        op1 = self.maxpool1(x)

        # block2 - 1 (identity shortcut: shapes already match)
        x = self.act2_1_1(self.batchnorm2_1_1(self.conv2_1_1(op1)))
        self.activations['conv2_1_1'] = x.detach()
        self.activations_for_redo['conv2_1_1'] = (x, 'conv', 'conv')
        x = self.batchnorm2_1_2(self.conv2_1_2(x))
        self.activations['conv2_1_2'] = x.detach()
        self.activations_for_redo['conv2_1_2'] = (x, 'conv', 'conv')
        x = self.dropout2_1(x)
        op2_1 = self.act2_1_2(x + op1)
        # block2 - 2
        x = self.act2_2_1(self.batchnorm2_2_1(self.conv2_2_1(op2_1)))
        self.activations['conv2_2_1'] = x.detach()
        self.activations_for_redo['conv2_2_1'] = (x, 'conv', 'conv')
        x = self.batchnorm2_2_2(self.conv2_2_2(x))
        self.activations['conv2_2_2'] = x.detach()
        self.activations_for_redo['conv2_2_2'] = (x, 'conv', 'conv')
        x = self.dropout2_2(x)
        op2 = self.act2_2_2(x + op2_1)

        # block3 - 1 (stride 2, projection shortcut)
        x = self.act3_1_1(self.batchnorm3_1_1(self.conv3_1_1(op2)))
        self.activations['conv3_1_1'] = x.detach()
        self.activations_for_redo['conv3_1_1'] = (x, 'conv', 'conv')
        x = self.batchnorm3_1_2(self.conv3_1_2(x))
        self.activations['conv3_1_2'] = x.detach()
        self.activations_for_redo['conv3_1_2'] = (x, 'conv', 'conv')
        x = self.dropout3_1(x)
        op2 = self.batchnorm_adjust_3(self.conv_concat_adjust_3(op2))  # skip connection
        self.activations_for_redo['conv_concat_adjust_3'] = (op2, 'conv', 'conv')
        op3_1 = self.act3_1_2(x + op2)
        # block3 - 2 (identity shortcut)
        x = self.act3_2_1(self.batchnorm3_2_1(self.conv3_2_1(op3_1)))
        self.activations['conv3_2_1'] = x.detach()
        self.activations_for_redo['conv3_2_1'] = (x, 'conv', 'conv')
        x = self.batchnorm3_2_2(self.conv3_2_2(x))
        self.activations['conv3_2_2'] = x.detach()
        self.activations_for_redo['conv3_2_2'] = (x, 'conv', 'conv')
        x = self.dropout3_2(x)
        op3 = self.act3_2_2(x + op3_1)

        # block4 - 1 (stride 2, projection shortcut)
        x = self.act4_1_1(self.batchnorm4_1_1(self.conv4_1_1(op3)))
        self.activations['conv4_1_1'] = x.detach()
        self.activations_for_redo['conv4_1_1'] = (x, 'conv', 'conv')
        x = self.batchnorm4_1_2(self.conv4_1_2(x))
        self.activations['conv4_1_2'] = x.detach()
        self.activations_for_redo['conv4_1_2'] = (x, 'conv', 'conv')
        x = self.dropout4_1(x)
        op3 = self.batchnorm_adjust_4(self.conv_concat_adjust_4(op3))  # skip connection
        self.activations_for_redo['conv_concat_adjust_4'] = (op3, 'conv', 'conv')
        op4_1 = self.act4_1_2(x + op3)
        # block4 - 2 (identity shortcut)
        x = self.act4_2_1(self.batchnorm4_2_1(self.conv4_2_1(op4_1)))
        self.activations['conv4_2_1'] = x.detach()
        self.activations_for_redo['conv4_2_1'] = (x, 'conv', 'conv')
        x = self.batchnorm4_2_2(self.conv4_2_2(x))
        self.activations['conv4_2_2'] = x.detach()
        self.activations_for_redo['conv4_2_2'] = (x, 'conv', 'conv')
        x = self.dropout4_2(x)
        op4 = self.act4_2_2(x + op4_1)

        # block5 - 1 (stride 2, projection shortcut)
        x = self.act5_1_1(self.batchnorm5_1_1(self.conv5_1_1(op4)))
        self.activations['conv5_1_1'] = x.detach()
        self.activations_for_redo['conv5_1_1'] = (x, 'conv', 'conv')
        x = self.batchnorm5_1_2(self.conv5_1_2(x))
        self.activations['conv5_1_2'] = x.detach()
        self.activations_for_redo['conv5_1_2'] = (x, 'conv', 'fc')
        x = self.dropout5_1(x)
        op4 = self.batchnorm_adjust_5(self.conv_concat_adjust_5(op4))  # skip connection
        self.activations_for_redo['conv_concat_adjust_5'] = (op4, 'conv', 'fc')
        op5_1 = self.act5_1_2(x + op4)
        # block5 - 2 (identity shortcut)
        x = self.act5_2_1(self.batchnorm5_2_1(self.conv5_2_1(op5_1)))
        self.activations['conv5_2_1'] = x.detach()
        self.activations_for_redo['conv5_2_1'] = (x, 'conv', 'conv')
        x = self.batchnorm5_2_2(self.conv5_2_2(x))
        self.activations['conv5_2_2'] = x.detach()
        self.activations_for_redo['conv5_2_2'] = (x, 'conv', 'fc')
        x = self.dropout5_2(x)
        op5 = self.act5_2_2(x + op5_1)

        # classifier
        x = self.avgpool(op5)
        x = x.reshape(x.shape[0], -1)
        return self.fc1(x)

    def get_model_weights_l2_norm(self):
        """Return ``utils_l2_norm`` over all named parameters."""
        return utils_l2_norm(self.named_parameters())

    def compute_l1_norm(self):
        """Return ``utils_l1_norm`` over all named parameters."""
        return utils_l1_norm(self.named_parameters())

    def compute_l2_norm(self):
        """Return ``new_utils_l2_norm`` over all named parameters."""
        return new_utils_l2_norm(self.named_parameters())

    def compute_total_params(self):
        """Return the parameter count (as a float), excluding bookkeeping tensors.

        Parameters whose name contains 'layer_norm', 'init_params' or
        'original_last_layer_params' are not counted.
        """
        excluded = ('layer_norm', 'init_params', 'original_last_layer_params')
        return sum(
            (param.numel() for name, param in self.named_parameters()
             if not any(tag in name for tag in excluded)),
            0.,
        )
    
