import os
import sys
import torch.nn.functional as F
import torch.nn as nn
import torch
import torch.nn.init as init
import numpy as np

def extend_image_append_label(x, y, opts, num_ext=1):
    """Zero-pad an image batch and write the label vector into the padded border.

    The input is padded by ``num_ext`` zero pixels on every side. For each of
    the ``num_ext`` outermost rings, ``y`` is written into the left and right
    columns at rows 12..21 and into the top and bottom rows at columns 12..21.
    The resulting label map is replicated over 3 channels and added to the
    padded image. On a 32x32 input with ``num_ext=1`` (34x34 after padding)
    this centres a 10-entry label on each side of the border.

    Args:
        x: image batch, assumed (batch, 3, H, H). The padded size is read from
            the last axis and used for both spatial axes.
        y: label vectors of shape (batch, 10), or a single (10,) vector that is
            broadcast to the whole batch (10 == width of the label slot).
        opts: kept for API compatibility. The label map is allocated on
            ``x.device`` so that it always matches the image it is added to.
        num_ext: border width in pixels.

    Returns:
        Tensor of shape (batch, 3, H + 2*num_ext, H + 2*num_ext).
    """
    x_ce = F.pad(x, (num_ext, num_ext, num_ext, num_ext))
    size = x_ce.shape[-1]
    ll = torch.zeros(x.shape[0], size, size, device=x.device)
    if y.dim() == 1:
        y = y.unsqueeze(0).expand(x.shape[0], -1)

    span = slice(12, 22)  # 10 label slots along each side
    for i in range(num_ext):
        ll[:, span, i] = y              # left column of ring i
        ll[:, span, size - 1 - i] = y   # right column of ring i
        ll[:, i, span] = y              # top row of ring i
        # The bottom row must mirror the top row. A hard-coded ``33 - i`` is
        # only the last row when size == 34; for wider borders it would write
        # the label into the image interior.
        ll[:, size - 1 - i, span] = y
    ll = ll.unsqueeze(1).repeat(1, 3, 1, 1)
    return x_ce + ll  # insert the label around the image

def partition_and_reorder(images, n_blocks):
    """Cut each image into an n_blocks x n_blocks grid and swap adjacent blocks.

    Blocks are numbered row-major (p = row * n_blocks + col). Blocks 0<->1,
    2<->3, ... exchange positions; with an odd number of blocks, the last one
    stays in place. The pixels inside each block are not altered, so applying
    the function twice returns the input.

    Args:
        images: tensor of shape (..., H, W), e.g. (batch, channels, H, W).
            H and W must both be divisible by ``n_blocks``.
        n_blocks: number of blocks along each spatial axis.

    Returns:
        Tensor with the same shape as ``images``.
    """
    height, width = images.shape[-2], images.shape[-1]
    assert height % n_blocks == 0 and width % n_blocks == 0, \
        "Image size should be divisible by the number of blocks"
    block_h = height // n_blocks
    block_w = width // n_blocks

    # The first unfold tiles the rows: (..., n_blocks, W, block_h). The width
    # axis is now at -2, so the second unfold must also target -2. Targeting -1
    # would tile the freshly appended block_h axis and transpose every block.
    blocks = images.unfold(-2, block_h, block_h).unfold(-2, block_w, block_w)
    # -> (..., n_blocks, n_blocks, block_h, block_w)
    lead = blocks.shape[:-4]
    n = n_blocks * n_blocks
    blocks = blocks.reshape(*lead, n, block_h, block_w)

    # Deterministic permutation: swap each adjacent pair of block indices.
    perm = torch.arange(n, device=images.device)
    even = n - n % 2
    perm[:even] = perm[:even].view(-1, 2).flip(1).reshape(-1)

    reordered = blocks.index_select(-3, perm)
    reordered = reordered.view(*lead, n_blocks, n_blocks, block_h, block_w)
    # (..., block_row, block_col, r, c) -> (..., block_row, r, block_col, c)
    return reordered.transpose(-3, -2).reshape(images.shape)



# it is an unshared convolutional layer
class LocallyConnected2d(nn.Module):
    """Unshared ("locally connected") 2-D convolution.

    Works like a stride-1 convolution, except that every output position
    (row, col) has its own ``in_channels x kernel_size x kernel_size`` filter:
    weights have shape (out_channels, in_channels, k, k, output_size, output_size)
    and the bias has shape (out_channels, output_size, output_size).
    """

    def __init__(self, in_channels=3, out_channels=3, output_size=32, kernel_size=11):
        super(LocallyConnected2d, self).__init__()

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.output_size = output_size

        weight_shape = (out_channels, in_channels, kernel_size, kernel_size, output_size, output_size)
        # Small positive initialisation: mean 1e-3, std 1e-3.
        self.weights = nn.Parameter(0.001 + 0.001 * torch.randn(*weight_shape))
        self.bias = nn.Parameter(torch.randn(out_channels, output_size, output_size))

    def forward(self, x, y=None, opts=None):
        """Apply the layer to ``x`` (batch, in_channels, S, S).

        The input is padded symmetrically so that the padded side equals
        ``output_size + kernel_size - 1``. When ``y`` is given, the border is
        produced by ``extend_image_append_label(x, y, opts, ...)`` (labels
        embedded in the padding). When ``y`` is None, plain zero padding is used,
        which lets label-free callers invoke the layer as ``layer(x)``.

        Returns:
            Contiguous tensor of shape (batch, out_channels, output_size, output_size).
        """
        # Symmetric padding is only exact when the total padding is even
        # (e.g. odd kernel with equal input/output sizes); the check below
        # catches other configurations.
        padding_width = (self.output_size + self.kernel_size - 1 - x.shape[-2]) // 2
        if y is None:
            x_padded = F.pad(x, (padding_width,) * 4)
        else:
            x_padded = extend_image_append_label(x, y, opts, num_ext=padding_width)
        in_channels, input_size = x_padded.shape[-3], x_padded.shape[-2]
        batch_size = x.shape[0]

        assert input_size == self.output_size + self.kernel_size - 1, \
            "padded input size does not match output_size + kernel_size - 1"

        # F.unfold -> (B, C*k*k, L) with L = output_size**2 positions in
        # row-major order, so it can be viewed as (B, C, k, k, out, out).
        x_unf = F.unfold(x_padded, self.kernel_size).view(
            batch_size, in_channels, self.kernel_size, self.kernel_size,
            self.output_size, self.output_size)

        # Per-position dot product over (in_channel, kernel_row, kernel_col).
        # einsum avoids materialising the (B, out, in, k, k, out, out) product.
        out = torch.einsum('oiklhw,biklhw->bohw', self.weights, x_unf)
        return (out + self.bias.unsqueeze(0)).contiguous()

class locally_block(nn.Module):
    """LayerNorm -> label-bordered LocallyConnected2d -> straight-through ReLU -> block shuffle.

    The block's output is also kept in ``self.h`` so that a parent model can
    collect it after the forward pass.
    """

    def __init__(self, out_size = 34, kernel_size = 11, first = False, in_size = 32):
        super().__init__()
        # Construction order is kept fixed so parameter initialisation draws
        # from the RNG in the same sequence.
        self.local_conn1 = LocallyConnected2d(3, 3, out_size, kernel_size)
        self.bn = nn.BatchNorm2d(3)  # registered, not used in forward()
        self.norm = nn.LayerNorm([3, in_size, in_size], elementwise_affine=False)
        self.relu = ReLU_full_grad.apply
        self.first = first  # stored, not read in forward()
        self.dropout = nn.Dropout(0.2)  # registered, not used in forward()

    def forward(self, h, y, opts):
        normalised = self.norm(h)
        activated = self.relu(self.local_conn1(normalised, y, opts))
        # Fixed 4x4 block permutation of the activation map.
        self.h = partition_and_reorder(activated, 4)
        return self.h

class ReLU_full_grad(torch.autograd.Function):
    """Straight-through ReLU.

    The forward pass clamps negative values to zero. The backward pass returns
    a copy of the incoming gradient unchanged, i.e. the ReLU's derivative is
    treated as 1 everywhere, regardless of the input's sign.
    """

    @staticmethod
    def forward(ctx, x):
        # Nothing is saved on ctx: backward does not depend on the input.
        return torch.clamp(x, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        passthrough = grad_output.clone()
        return passthrough


# BP baseline network 2 hidden layers
class MLP_Net_Receptive(nn.Module):
    """Backprop baseline: two locally connected stages and a linear classifier.

    Each stage is LocallyConnected2d(3 -> 3, 32x32, k=11) -> BatchNorm2d ->
    in-place ReLU -> Dropout(0.5). The final 3x32x32 map is flattened into a
    10-way linear layer.
    """

    def __init__(self):
        super(MLP_Net_Receptive, self).__init__()
        # Submodules are created in the same order as before so that
        # initialisation consumes the RNG identically.
        self.local_conn1 = LocallyConnected2d(3, 3, 32, 11)
        self.bn1 = nn.BatchNorm2d(3)
        self.relu = nn.ReLU(True)
        self.local_conn2 = LocallyConnected2d(3, 3, 32, 11)
        self.bn2 = nn.BatchNorm2d(3)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(32 * 32 * 3, 10)
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)

    def forward(self, x):
        # NOTE(review): the locally connected layers are called with the image
        # only; this requires LocallyConnected2d.forward to accept calls
        # without y/opts — confirm against its signature.
        stages = (
            (self.local_conn1, self.bn1, self.dropout1),
            (self.local_conn2, self.bn2, self.dropout2),
        )
        for conn, norm, drop in stages:
            x = drop(self.relu(norm(conn(x))))
        return self.fc(self.flatten(x))
    
class MLP_FF_Receptive(nn.Module):
    """Two stacked locally_block layers whose outputs are returned side by side.

    forward(x, y, opts) returns a (batch, features) tensor: the flattened
    output (``block.h``) of each block, concatenated along dim 1.
    """

    def __init__(self):
        super(MLP_FF_Receptive, self).__init__()
        first_block = locally_block(out_size=32, first=True, in_size=32)
        second_block = locally_block(out_size=32)
        self.h = []
        self.blocks = nn.Sequential(first_block, second_block)

    def forward(self, x, y, opts):
        for block in self.blocks:
            x = block(x, y, opts)
        flattened = [block.h.view(block.h.shape[0], -1) for block in self.blocks]
        return torch.cat(flattened, dim=1)

# Constants for the frozen LayerNorm affine parameters used by the conv blocks
# in this module (bias is filled with mu, weight with delta). They are read
# when a block is constructed.
mu = 10
delta = 1


def set_mu_delta(mu_, delta_):
    """Replace the module-level LayerNorm constants ``mu`` and ``delta``.

    Only blocks constructed after this call pick up the new values; LayerNorm
    parameters that already exist are not modified.
    """
    global mu, delta
    mu, delta = mu_, delta_

class Conv_block(nn.Module):
    """Single conv + straight-through ReLU whose output is gated by a label embedding.

    forward(x, y, opts) returns (f, x_out):
      * f = relu(conv1(norm(x))) * lc(y), with the gate reshaped to the activation
        shape; f is also kept in ``self.h``.
      * x_out = the ungated activations, 2x2 max-pooled when ``pool`` is set.
    ``opts`` is accepted for interface compatibility but not used.
    """

    def __init__(self, in_channels, output_sz, input_sz, out_channels=32, kernel_size=3, stride=1, padding=0, init_norm=True, class_num=10, pool=False):
        super().__init__()
        self.init_norm = init_norm
        self.pool_en = pool
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # One gate value per activation; assumes conv1 yields
        # out_channels x output_sz x output_sz maps.
        self.lc = nn.Linear(class_num, out_channels * output_sz * output_sz)
        if pool:
            self.maxpool = nn.MaxPool2d(2, 2)
        if init_norm:
            # Plain (non-affine) normalisation over the whole input map.
            self.norm = nn.LayerNorm([in_channels, input_sz, input_sz], elementwise_affine=False)
        self.relu = ReLU_full_grad.apply

    def forward(self, x, y, opts):
        features = self.norm(x) if self.init_norm else x
        act = self.relu(self.conv1(features))
        gated = act * self.lc(y).view(act.shape)
        self.h = gated
        pooled = self.maxpool(act) if self.pool_en else act
        return gated, pooled

class Conv_block_v2(nn.Module):
    """Conv-ReLU-(LayerNorm)-Conv-ReLU block gated by a label embedding.

    forward(x, y, opts) returns (f, x_out):
      * f = activations * lc(y), with the gate reshaped to the activation shape;
        f is also kept in ``self.h``.
      * x_out = the ungated activations, plus a ReLU'd 1x1-conv shortcut of
        the (normalised) input when ``downsample`` is set, then 2x2 max-pooled
        when ``pool`` is set.
    With ``init_norm``, both LayerNorms get frozen affine parameters filled
    from the module-level ``delta`` (weight) and ``mu`` (bias) at construction.
    With a non-zero ``encoded``, ``lc`` takes ``encoded**2`` label features and
    ``y`` is tiled column-wise to that width when needed. ``opts`` is unused.
    """

    def __init__(self, in_channels, output_sz, input_sz, out_channels=32, kernel_size=3, stride=1, padding=0, init_norm=True, class_num=10, pool=False, encoded = 0, downsample=False):
        super().__init__()
        self.init_norm = init_norm
        self.pool_en = pool
        self.encoded = encoded
        self.downsample = downsample

        # Submodules are created in a fixed order so that parameter
        # initialisation consumes the RNG identically.
        conv_args = (kernel_size, stride, padding)
        self.conv1 = nn.Conv2d(in_channels, out_channels, *conv_args)
        self.conv2 = nn.Conv2d(out_channels, out_channels, *conv_args)
        label_dim = encoded * encoded if encoded else class_num
        self.lc = nn.Linear(label_dim, output_sz * output_sz * out_channels)
        if downsample:
            self.conv4 = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
        if pool:
            self.maxpool = nn.MaxPool2d(2, 2)
        if init_norm:
            self.norm = nn.LayerNorm([in_channels, input_sz, input_sz], elementwise_affine=True)
            self.norm2 = nn.LayerNorm([out_channels, output_sz, output_sz], elementwise_affine=True)
            for layer_norm in (self.norm, self.norm2):
                layer_norm.weight.data.fill_(delta)
                layer_norm.bias.data.fill_(mu)
                layer_norm.weight.requires_grad = False
                layer_norm.bias.requires_grad = False

        self.relu = ReLU_full_grad.apply

    @staticmethod
    def _tile_labels(y, width):
        """Repeat the columns of ``y`` cyclically so that it has ``width`` columns."""
        if y.shape[1] == width:
            return y
        reps, rem = divmod(width, y.shape[1])
        tiled = y.repeat(1, reps).reshape(y.shape[0], -1)
        tail = y[:, :rem].reshape(y.shape[0], -1)
        return torch.cat([tiled, tail], dim=1)

    def forward(self, x, y, opts):
        if self.init_norm:
            x = self.norm(x)
        shortcut = self.conv4(x) if self.downsample else None

        x = self.relu(self.conv1(x))
        if self.init_norm:
            x = self.norm2(x)
        x = self.relu(self.conv2(x))

        if self.encoded:
            y = self._tile_labels(y, self.encoded * self.encoded)
        gated = x * self.lc(y).view(x.shape)
        self.h = gated

        if self.downsample:
            x = self.relu(x + shortcut)
        if self.pool_en:
            x = self.maxpool(x)
        return gated, x
    
class Conv_block_v3(nn.Module):
    """Three-conv block gated by a label embedding.

    Path: [LayerNorm] -> conv1 -> ReLU -> [LayerNorm2] -> conv2 (+ 1x1-conv
    shortcut of the normalised input when ``downsample``) -> ReLU ->
    [LayerNorm2] -> conv3 -> ReLU. LayerNorm2 is the same module used twice.

    forward(x, y, opts) returns (f, x_out), where f = activations * lc(y)
    (also kept in ``self.h``) and x_out is the activations, 2x2 max-pooled when
    ``pool`` is set. LayerNorm affine parameters are frozen and filled from the
    module-level ``delta``/``mu`` at construction. With ``encoded != 0``, ``y``
    is tiled column-wise to ``encoded**2`` features. ``opts`` is unused.
    """

    def __init__(self, in_channels, output_sz, input_sz, out_channels=32, kernel_size=3, stride=1, padding=0, init_norm=True, class_num=10, pool=False, encoded = 0, downsample=False):
        super().__init__()
        self.init_norm = init_norm
        self.pool_en = pool
        self.encoded = encoded
        self.downsample = downsample

        # Submodules are created in a fixed order so that parameter
        # initialisation consumes the RNG identically.
        conv_args = (kernel_size, stride, padding)
        self.conv1 = nn.Conv2d(in_channels, out_channels, *conv_args)
        self.conv2 = nn.Conv2d(out_channels, out_channels, *conv_args)
        self.conv3 = nn.Conv2d(out_channels, out_channels, *conv_args)
        if downsample:
            self.conv4 = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)

        label_dim = encoded * encoded if encoded != 0 else class_num
        self.lc = nn.Linear(label_dim, output_sz * output_sz * out_channels)
        if pool:
            self.maxpool = nn.MaxPool2d(2, 2)
        if init_norm:
            self.norm = nn.LayerNorm([in_channels, input_sz, input_sz], elementwise_affine=True)
            self.norm2 = nn.LayerNorm([out_channels, output_sz, output_sz], elementwise_affine=True)
            for layer_norm in (self.norm, self.norm2):
                layer_norm.weight.data.fill_(delta)
                layer_norm.bias.data.fill_(mu)
                layer_norm.weight.requires_grad = False
                layer_norm.bias.requires_grad = False

        self.relu = ReLU_full_grad.apply

    @staticmethod
    def _tile_labels(y, width):
        """Repeat the columns of ``y`` cyclically so that it has ``width`` columns."""
        if y.shape[1] == width:
            return y
        reps, rem = divmod(width, y.shape[1])
        tiled = y.repeat(1, reps).reshape(y.shape[0], -1)
        tail = y[:, :rem].reshape(y.shape[0], -1)
        return torch.cat([tiled, tail], dim=1)

    def forward(self, x, y, opts):
        if self.init_norm:
            x = self.norm(x)
        shortcut = self.conv4(x) if self.downsample else None

        x = self.relu(self.conv1(x))
        if self.init_norm:
            x = self.norm2(x)

        x = self.conv2(x)
        if self.downsample:
            x = x + shortcut
        x = self.relu(x)
        if self.init_norm:
            x = self.norm2(x)

        x = self.relu(self.conv3(x))

        if self.encoded != 0:
            y = self._tile_labels(y, self.encoded * self.encoded)
        gated = x * self.lc(y).view(x.shape)
        self.h = gated

        # NOTE(review): x is already a ReLU output here, so this ReLU leaves the
        # values unchanged and its backward passes the gradient straight through.
        x = self.relu(x)
        if self.pool_en:
            x = self.maxpool(x)
        return gated, x

class Conv_FF_model_v2(nn.Module):
    """Stack of label-gated conv blocks selected by ``combo`` (0-5).

    forward(x, y, opts, diff_res) runs every block in order (each consumes the
    previous block's second output). It returns a (batch, len(diff_res))
    tensor whose column i is the sum of squares of block i's gated
    activations (``block.h``), scaled by ``diff_res[i] / min(diff_res)``.
    An unknown ``combo`` yields an empty stack.
    """

    def __init__(self, combo=0):
        super(Conv_FF_model_v2, self).__init__()
        self.combo = combo
        v1, v2, v3 = Conv_block, Conv_block_v2, Conv_block_v3
        pool = {'pool': True}
        pool_ds = {'pool': True, 'downsample': True}
        # combo -> [(block class, in_channels, spatial size, out_channels, extra kwargs)]
        # input_sz == output_sz for every block.
        layouts = {
            0: [(v2, 3, 32, 32, pool), (v2, 32, 16, 64, {})],
            1: [(v2, 3, 32, 32, pool), (v1, 32, 16, 64, {}), (v1, 64, 16, 64, pool),
                (v1, 64, 8, 128, {}), (v1, 128, 8, 128, {})],
            2: [(v2, 3, 32, 32, pool), (v1, 32, 16, 64, {}), (v1, 64, 16, 64, pool),
                (v1, 64, 8, 128, {}), (v1, 128, 8, 128, pool),
                (v1, 128, 4, 256, {}), (v1, 256, 4, 256, {})],
            3: [(v2, 3, 32, 32, pool), (v2, 32, 16, 64, pool),
                (v1, 64, 8, 128, {}), (v1, 128, 8, 128, pool),
                (v1, 128, 4, 256, {}), (v1, 256, 4, 256, {}),
                (v1, 256, 4, 256, {}), (v1, 256, 4, 256, {})],
            4: [(v2, 3, 32, 64, pool), (v2, 64, 16, 128, pool),
                (v3, 128, 8, 256, pool), (v3, 256, 4, 512, pool)],
            5: [(v3, 3, 32, 64, pool_ds), (v3, 64, 16, 128, pool_ds),
                (v3, 128, 8, 256, pool_ds), (v3, 256, 4, 512, pool_ds)],
        }
        # Every block uses 3x3 "same" convolutions with the input LayerNorm on.
        blocks = [
            cls(in_channels=c_in, output_sz=size, input_sz=size, out_channels=c_out,
                kernel_size=3, stride=1, padding=1, init_norm=True, **extra)
            for cls, c_in, size, c_out, extra in layouts.get(combo, [])
        ]
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x, y, opts, diff_res):
        for block in self.blocks:
            _, x = block(x, y, opts)
        base = np.min(diff_res)
        energies = []
        for idx, weight in enumerate(diff_res):
            act = self.blocks[idx].h
            energies.append(act.view(act.shape[0], -1).pow(2).sum(1) * weight / base)
        return torch.stack(energies, dim=1)
    
        
class Conv_block_bp(nn.Module):
    """Backprop baseline: a small VGG-style CNN for 3x32x32 inputs and 10 classes.

    ``combo`` selects the depth:
      0: conv stages (32, 32) | (64, 64), classifier 64*16*16 -> 48 -> 10
      1: adds a (128, 128) stage, classifier 128*8*8 -> 128 -> 10
      2: adds a (256, 256) stage, classifier 256*4*4 -> 256 -> 10
      3: last stage is (256, 256, 256, 256) followed by a pool,
         classifier 256*2*2 -> 512 -> 128 -> 10
    All convs are 3x3, stride 1, padding 1, followed by ReLU. Stages are
    separated by 2x2 max-pooling. An unknown ``combo`` builds no layers.
    """

    def __init__(self, combo=0):
        super(Conv_block_bp, self).__init__()
        self.combo = combo
        if combo == 0:
            self.features = self._conv_stack([[32, 32], [64, 64]], pool_last=False)
            self.classifier = self._mlp([64 * 16 * 16, 48, 10])
        elif combo == 1:
            self.features = self._conv_stack([[32, 32], [64, 64], [128, 128]], pool_last=False)
            self.classifier = self._mlp([128 * 8 * 8, 128, 10])
        elif combo == 2:
            self.features = self._conv_stack(
                [[32, 32], [64, 64], [128, 128], [256, 256]], pool_last=False)
            self.classifier = self._mlp([256 * 4 * 4, 256, 10])
        elif combo == 3:
            self.features = self._conv_stack(
                [[32, 32], [64, 64], [128, 128], [256, 256, 256, 256]], pool_last=True)
            self.classifier = self._mlp([256 * 2 * 2, 512, 128, 10])

    @staticmethod
    def _conv_stack(stages, pool_last):
        """Conv3x3+ReLU layers per stage, a 2x2 max-pool between stages, then Flatten."""
        layers = []
        in_ch = 3
        for stage_idx, widths in enumerate(stages):
            for width in widths:
                layers += [nn.Conv2d(in_ch, width, 3, 1, 1), nn.ReLU()]
                in_ch = width
            if pool_last or stage_idx < len(stages) - 1:
                layers.append(nn.MaxPool2d(2, 2))
        layers.append(nn.Flatten())
        return nn.Sequential(*layers)

    @staticmethod
    def _mlp(sizes):
        """Linear layers through ``sizes`` with a ReLU between consecutive layers."""
        layers = []
        for pos, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            if pos:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(n_in, n_out))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.classifier(self.features(x))

class VGG16(nn.Module):
    """VGG-style CNN for 3x32x32 inputs and 10 classes.

    Four conv stages with 2, 2, 3 and 3 convolutions of width 64/128/256/512
    (3x3, padding 1, ReLU). Each stage ends with 2x2 max-pooling and
    Dropout(0.25); the classic fifth stage is omitted. The classifier is
    2048 -> 4096 -> 4096 -> 10 with ReLU and Dropout(0.5).
    """

    def __init__(self):
        super(VGG16, self).__init__()

        layers = []
        in_ch = 3
        for width, n_convs in ((64, 2), (128, 2), (256, 3), (512, 3)):
            for _ in range(n_convs):
                layers.append(nn.Conv2d(in_ch, width, kernel_size=3, padding=1))
                layers.append(nn.ReLU())
                in_ch = width
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            layers.append(nn.Dropout(0.25))
        layers.append(nn.Flatten())
        self.features = nn.Sequential(*layers)

        # Four 2x2 max-pools shrink 32x32 to 2x2, giving 512 * 2 * 2 features.
        self.classifier = nn.Sequential(
            nn.Linear(512 * 2 * 2, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 10),  # CIFAR-10 has 10 classes
        )

    def forward(self, x):
        return self.classifier(self.features(x))
    
    

class ResNetBlock(nn.Module):
    """Basic residual block: conv3x3-BN-ReLU-conv3x3-BN, add shortcut, ReLU.

    When the stride or the channel count changes, the shortcut is a 1x1 conv +
    BN projection so that both branches have matching shapes.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNetBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        
    def forward(self, x):
        identity = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)
        
        out += identity
        out = self.relu(out)
        
        return out


class ResNet8(nn.Module):
    """Small CIFAR-style ResNet built from ResNetBlock stages.

    ``combo`` selects the number of blocks per stage (stage widths 16/32/64[/128];
    the first stage has stride 1, later stages stride 2):
      4: (1, 1, 1)    5: (1, 2, 1)    6: (1, 2, 2)    7: (1, 1, 1, 1) with a 128-wide 4th stage
    The classifier input size follows the width of the last stage.

    Raises:
        ValueError: if ``combo`` is not one of 4, 5, 6, 7.
    """

    def __init__(self, combo, num_classes=10):
        super(ResNet8, self).__init__()
        
        self.in_channels = 16
        
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        # Optional fourth stage, built only for combo 7. Declared up front so
        # that forward() can test for it (nn.Module raises AttributeError for
        # attributes that were never assigned).
        self.layer4 = None
        if combo == 4:
            self.layer1 = self.make_layer(ResNetBlock, 16, 1, stride=1)
            self.layer2 = self.make_layer(ResNetBlock, 32, 1, stride=2)
            self.layer3 = self.make_layer(ResNetBlock, 64, 1, stride=2)
        elif combo == 5:
            self.layer1 = self.make_layer(ResNetBlock, 16, 1, stride=1)
            self.layer2 = self.make_layer(ResNetBlock, 32, 2, stride=2)
            self.layer3 = self.make_layer(ResNetBlock, 64, 1, stride=2)
        elif combo == 6:
            self.layer1 = self.make_layer(ResNetBlock, 16, 1, stride=1)
            self.layer2 = self.make_layer(ResNetBlock, 32, 2, stride=2)
            self.layer3 = self.make_layer(ResNetBlock, 64, 2, stride=2)
        elif combo == 7:
            self.layer1 = self.make_layer(ResNetBlock, 16, 1, stride=1)
            self.layer2 = self.make_layer(ResNetBlock, 32, 1, stride=2)
            self.layer3 = self.make_layer(ResNetBlock, 64, 1, stride=2)
            self.layer4 = self.make_layer(ResNetBlock, 128, 1, stride=2)
        else:
            raise ValueError(f"unsupported ResNet8 combo: {combo!r} (expected 4, 5, 6 or 7)")
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # make_layer leaves self.in_channels at the last stage's width
        # (64 for combos 4-6, 128 for combo 7), i.e. the pooled feature size.
        self.fc = nn.Linear(self.in_channels, num_classes)
        
    def make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage: a (possibly strided) first block plus num_blocks-1 stride-1 blocks."""
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for _ in range(1, num_blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.layer4 is not None:
            x = self.layer4(x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        
        return x

class ResBlock_ff(nn.Module):
    """Residual conv block with frozen-affine LayerNorms and a label-conditioned output.

    ``forward(x, y, opts)`` returns ``(f, act)``:

    * ``act`` is the block activation (what ``FF_sequential`` passes on);
    * ``f`` is ``act`` — adaptively max-pooled to ``pool_en x pool_en`` when
      ``pool_en`` is truthy — multiplied elementwise by ``self.lc(y)`` reshaped
      to the same shape. ``f`` is also cached on ``self.h``.

    Every LayerNorm has its weight filled with the module-level ``delta`` and
    its bias with ``mu`` (both defined elsewhere in this file), and both are
    frozen with ``requires_grad = False``.

    Args:
        input_sz: spatial size of the tensor entering ``norm1``.
        output_sz: spatial size produced by ``conv1``; the caller must pass the
            size ``conv1`` (kernel 3, padding 1, ``stride``) actually produces.
        in_channels: input channels of ``conv1``.
        out_channels: output channels of ``conv1``/``conv2``.
        initial_bl: if True, first normalize a 3-channel input with ``norm0``
            and lift it to ``in_channels`` with ``conv0``.
        stride: stride of ``conv1`` and of the 1x1 shortcut projection.
        extended: if True, append a second stage (``conv3``/``conv4``) that
            doubles the channels and halves the spatial size (rounding up);
            its output replaces the first stage's as ``act``.
        pool_en: if non-zero, max-pool ``act`` to ``pool_en x pool_en`` before
            the label gate.

    ``y`` is fed to ``nn.Linear(10, ...)``, so it must have 10 features.
    """

    def __init__(self, input_sz, output_sz, in_channels, out_channels, initial_bl=False ,stride=1, extended = False, pool_en=0):
        super(ResBlock_ff, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.initial_bl = initial_bl
        self.pool_en = pool_en
        if initial_bl:
            self.norm0 = self._frozen_norm([3, input_sz, input_sz])
            self.conv0 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm1 = self._frozen_norm([self.in_channels, input_sz, input_sz])
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.norm2 = self._frozen_norm([self.out_channels, output_sz, output_sz])
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            # 1x1 projection so the shortcut matches the residual's shape.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            )
        self.relu = nn.ReLU(inplace=True)

        self.extended = extended
        self.out_sz = output_sz
        if self.pool_en != 0:
            self.maxpool = nn.AdaptiveMaxPool2d((pool_en, pool_en))
        if extended:
            extended_channels = 2 * out_channels
            # conv3 (k=3, s=2, p=1) and downsample2 (k=1, s=2) both produce
            # ceil(output_sz / 2); output_sz // 2 was wrong for odd sizes and
            # made norm3 reject conv3's output.
            extended_sz = (output_sz + 1) // 2
            self.conv3 = nn.Conv2d(out_channels, extended_channels, kernel_size=3, stride=2, padding=1, bias=False)
            self.norm3 = self._frozen_norm([extended_channels, extended_sz, extended_sz])
            self.conv4 = nn.Conv2d(extended_channels, extended_channels, kernel_size=3, stride=1, padding=1, bias=False)
            self.downsample2 = nn.Sequential(
                nn.Conv2d(out_channels, extended_channels, kernel_size=1, stride=2, bias=False),
            )
            if self.pool_en != 0:
                extended_sz = pool_en
            self.lc = nn.Linear(10, extended_sz * extended_sz * extended_channels)
        else:
            if self.pool_en:
                self.lc = nn.Linear(10, pool_en * pool_en * out_channels)
            else:
                self.lc = nn.Linear(10, output_sz * output_sz * out_channels)

    @staticmethod
    def _frozen_norm(shape):
        """Return a LayerNorm over ``shape`` with weight=delta, bias=mu, both frozen."""
        norm = nn.LayerNorm(shape, elementwise_affine=True)
        norm.weight.data.fill_(delta)
        norm.bias.data.fill_(mu)
        norm.weight.requires_grad = False
        norm.bias.requires_grad = False
        return norm

    def _label_gate(self, act, y):
        """Optionally max-pool ``act``, multiply by ``lc(y)`` reshaped to match, cache on ``self.h``."""
        pooled = self.maxpool(act) if self.pool_en else act
        c = self.lc(y)
        f = pooled * c.view(c.shape[0], pooled.shape[1], pooled.shape[2], pooled.shape[3])
        self.h = f
        return f

    def forward(self, x, y, opts):
        """Return ``(f, act)`` as described on the class; ``opts`` is unused here."""
        if self.initial_bl:
            x = self.norm0(x)
            x = self.conv0(x)
        x = self.norm1(x)
        identity = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.norm2(out)
        out = self.conv2(out)

        if self.extended:
            # NOTE(review): the first stage's shortcut (identity / self.downsample)
            # is not added on this path — confirm that is intended.
            # self.relu is in-place, so identity2 (an alias of ``out``) is also
            # rectified by the next statement before it feeds downsample2.
            identity2 = out
            out2 = self.relu(out)
            out2 = self.norm2(out2)
            out2 = self.conv3(out2)
            out2 = self.relu(out2)
            out2 = self.norm3(out2)
            out2 = self.conv4(out2)
            out2 += self.downsample2(identity2)
            out2 = self.relu(out2)
            return self._label_gate(out2, y), out2

        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return self._label_gate(out, y), out
    
class ResBlock_ff_v2(nn.Module):
    """FF residual block whose LayerNorms have no affine parameters.

    The optional extended stage keeps both the channel count and the spatial
    size (stride-1 ``conv3``), so its shortcut is a plain identity with no
    projection.

    ``forward(x, y, opts)`` returns ``(f, act)``: ``act`` is the block
    activation passed on by ``FF_sequential``; ``f`` is ``act`` (adaptively
    max-pooled to ``pool_en x pool_en`` when ``pool_en`` is truthy) multiplied
    elementwise by ``self.lc(y)`` reshaped to the same shape, and is cached on
    ``self.h``. ``opts`` is not used. ``y`` feeds ``nn.Linear(10, ...)``.
    """

    def __init__(self, input_sz, output_sz, in_channels, out_channels, initial_bl=False ,stride=1, extended = False, pool_en=0):
        super(ResBlock_ff_v2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.initial_bl = initial_bl
        self.pool_en = pool_en
        out_shape = [out_channels, output_sz, output_sz]

        if initial_bl:
            self.norm0 = nn.LayerNorm([3, input_sz, input_sz], elementwise_affine=False)
            self.conv0 = nn.Conv2d(3, in_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm1 = nn.LayerNorm([in_channels, input_sz, input_sz], elementwise_affine=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.norm2 = nn.LayerNorm(out_shape, elementwise_affine=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)

        # 1x1 projection on the shortcut only when residual and input shapes differ.
        projection_needed = stride != 1 or in_channels != out_channels
        self.downsample = (
            nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False))
            if projection_needed
            else None
        )
        self.relu = nn.ReLU(inplace=True)

        self.extended = extended
        self.out_sz = output_sz
        if pool_en != 0:
            self.maxpool = nn.AdaptiveMaxPool2d((pool_en, pool_en))

        # Side length of the tensor the label gate multiplies.
        if extended:
            self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            self.norm3 = nn.LayerNorm(out_shape, elementwise_affine=False)
            self.conv4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            gate_sz = pool_en if pool_en != 0 else output_sz
        else:
            gate_sz = pool_en if pool_en else output_sz
        self.lc = nn.Linear(10, gate_sz * gate_sz * out_channels)

    def _gate(self, act, y):
        """Pool ``act`` if enabled, scale it by ``lc(y)`` and cache the result on ``self.h``."""
        target = self.maxpool(act) if self.pool_en else act
        weights = self.lc(y)
        gated = target * weights.view(weights.shape[0], target.shape[1], target.shape[2], target.shape[3])
        self.h = gated
        return gated

    def forward(self, x, y, opts):
        """Return ``(f, act)`` as described on the class."""
        if self.initial_bl:
            x = self.conv0(self.norm0(x))
        x = self.norm1(x)
        shortcut = x
        feat = self.conv2(self.norm2(self.relu(self.conv1(x))))

        if not self.extended:
            if self.downsample is not None:
                shortcut = self.downsample(x)
            feat += shortcut
            feat = self.relu(feat)
            return self._gate(feat, y), feat

        # Extended stage. self.relu is in-place, so rectifying ``feat`` here also
        # rectifies the tensor added back as this stage's shortcut below.
        stage = self.norm2(self.relu(feat))
        stage = self.conv4(self.norm3(self.relu(self.conv3(stage))))
        stage += feat
        stage = self.relu(stage)
        return self._gate(stage, y), stage

class FF_sequential(nn.Sequential):
    """``nn.Sequential`` for blocks called as ``block(x, y, opts) -> (f, act)``.

    Each block receives the previous block's ``act`` plus the shared ``y`` and
    ``opts``. The per-block ``f`` values are discarded here; the FF blocks in
    this file also cache them on ``block.h``. Returns the last block's ``act``.
    """

    def forward(self, x, y, opts):
        act = x
        for block in self:
            _f, act = block(act, y, opts)
        return act

class Resnet_ff(nn.Module):
    """Predefined stack of ``ResBlock_ff`` blocks scored by per-block goodness.

    Args:
        combo: which block stack to build (0, 1 or 2).

    Raises:
        ValueError: if ``combo`` is not a supported configuration (previously
            this left ``self.blocks`` unset and only failed later in ``forward``).
    """

    def __init__(self, combo):
        super(Resnet_ff, self).__init__()
        self.in_channels = 16
        # An extended ResBlock_ff doubles its channels and halves its spatial
        # size, which is why the block after the extended 16-ch 32x32 block
        # takes 32 channels at 16x16.
        if combo == 0:
            self.blocks = FF_sequential(
                ResBlock_ff(input_sz=32, output_sz=32, in_channels=16, out_channels=16, initial_bl=True, extended=True, pool_en=4),
                ResBlock_ff(input_sz=16, output_sz=8, in_channels=32, out_channels=64, stride=2, pool_en=2),
            )
        elif combo == 1:
            self.blocks = FF_sequential(
                ResBlock_ff(input_sz=32, output_sz=32, in_channels=16, out_channels=16, initial_bl=True, extended=True, pool_en=4),
                ResBlock_ff(input_sz=16, output_sz=16, in_channels=32, out_channels=32, stride=1, pool_en=4),
                ResBlock_ff(input_sz=16, output_sz=8, in_channels=32, out_channels=64, stride=2, pool_en=3),
            )
        elif combo == 2:
            self.blocks = FF_sequential(
                ResBlock_ff(input_sz=32, output_sz=32, in_channels=16, out_channels=16, initial_bl=True, extended=True, pool_en=4),
                ResBlock_ff(input_sz=16, output_sz=16, in_channels=32, out_channels=32, stride=1, pool_en=4),
                ResBlock_ff(input_sz=16, output_sz=8, in_channels=32, out_channels=64, stride=2, pool_en=2),
                ResBlock_ff(input_sz=8, output_sz=8, in_channels=64, out_channels=64, stride=1, pool_en=2),
            )
        else:
            raise ValueError(f"Resnet_ff: unsupported combo {combo!r}; expected 0, 1 or 2")

    def forward(self, x, y, opts, diff_res):
        """Run every block and return each block's weighted goodness.

        Args:
            x, y, opts: passed through ``self.blocks`` (see ``FF_sequential``).
            diff_res: per-block weights; entry ``i`` scales block ``i``'s
                goodness by ``diff_res[i] / min(diff_res)``. It may not have
                more entries than there are blocks (IndexError otherwise);
                presumably all entries are positive — a zero minimum divides
                by zero. During training these weights can be ignored; they
                only weight the per-layer goodness at inference.

        Returns:
            Tensor of shape (batch, len(diff_res)); column ``i`` is the sum of
            squares of block ``i``'s cached gated output ``h`` times its weight.
        """
        # The final activation is not needed; running the stack refreshes
        # ``.h`` on every block.
        self.blocks(x, y, opts)
        base = np.min(diff_res)
        goodness = []
        for i, weight in enumerate(diff_res):
            h = self.blocks[i].h
            # reshape (not view) so a non-contiguous h is handled too.
            goodness.append(h.reshape(h.shape[0], -1).pow(2).sum(1) * weight / base)
        return torch.stack(goodness, dim=1)