import os
import sys
import torch.nn.functional as F
import torch.nn as nn
import torch
sys.path.append(os.path.abspath("./"))
from HAWQ.utils.quantization_utils.quant_modules import QuantAct, QuantLinear, QuantConv2d, QuantBnConv2d, QuantAveragePool2d
# from pytorchcv.model_provider import get_model as ptcv_get_model

class Block(nn.Module):
    """Single linear layer with optional input normalisation and ReLU.

    Args:
        in_dim: Input feature size of the linear layer.
        out_dim: Output feature size of the linear layer.
        normalize_input: When true, the input is passed through
            ``F.normalize(x, dim=1)`` before the linear layer.
        ff: When true, ``forward`` always returns the ReLU activation,
            regardless of ``output_layer``.
        output_layer: Only consulted when ``ff`` is false; the raw linear
            output is then returned without ReLU.
        bias: Whether the linear layer has a bias term.

    After each call the linear output is kept on ``self.x``.
    """

    def __init__(self, in_dim, out_dim, normalize_input, ff, output_layer=False, bias=False):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=bias)
        self.output_layer = output_layer
        self.relu = nn.ReLU(True)
        self.normalize_input = normalize_input
        self.ff = ff

    def forward(self, x):
        """Apply (optional) normalisation, the linear layer and (usually) ReLU."""
        layer_in = F.normalize(x, dim=1) if self.normalize_input else x
        self.x = self.fc(layer_in)

        # Only a non-ff output layer skips the activation.
        if self.output_layer and not self.ff:
            return self.x

        # The ReLU is in-place, so ``self.x`` and the returned tensor are the
        # same object and ``self.x`` holds the rectified values afterwards.
        return self.relu(self.x)


class Network(nn.Module):
    """Stack of ``Block`` layers built from a list of layer widths.

    Args:
        dims: Sequence of layer sizes; consecutive entries give the
            input/output size of each block.
        ff: Forwarded to every ``Block``; also selects what ``forward``
            returns (see below).
        bias: Forwarded to every ``Block``.
    """

    def __init__(self, dims, ff=True, bias=False):
        super().__init__()

        # The first block sees the raw input: no normalisation, never the
        # output layer.
        layers = [Block(dims[0], dims[1], False, ff, False, bias)]

        # Remaining blocks normalise their input; the last one is flagged as
        # the output layer.
        last_idx = len(dims) - 3
        for idx, (d_in, d_out) in enumerate(zip(dims[1:-1], dims[2:])):
            layers.append(Block(d_in, d_out, True, ff, idx == last_idx, bias))

        self.blocks = nn.Sequential(*layers)
        self.n_blocks = len(layers)
        self.ff = ff

    def forward(self, x, cat=True):
        """Run all blocks.

        Returns the final block output when ``ff`` is false. Otherwise
        returns the per-block ``x`` attributes, stacked along dim 1 when
        ``cat`` is true or as a plain list when it is false.
        """
        final = self.blocks(x)
        if not self.ff:
            return final

        per_block = [blk.x for blk in self.blocks.children()]
        return torch.stack(per_block, dim=1) if cat else per_block

##################### BP Modules #####################

#MLP
class Network_BP(nn.Module):
    """Three-layer MLP: two ReLU hidden layers followed by a linear output.

    Args:
        dims: Sequence whose first four entries are the input size, the two
            hidden sizes and the output size.
        bias: Whether the linear layers carry a bias term.
    """

    def __init__(self, dims, bias=False):
        super().__init__()
        n_in, n_h1, n_h2, n_out = dims[0], dims[1], dims[2], dims[3]
        self.fc1 = nn.Linear(n_in, n_h1, bias=bias)
        self.fc2 = nn.Linear(n_h1, n_h2, bias=bias)
        self.fc3 = nn.Linear(n_h2, n_out, bias=bias)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        """Return the output of ``fc3``; hidden activations are L2-normalised
        along dim 1 before being fed to the next layer."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(F.normalize(hidden, dim=1)))
        return self.fc3(F.normalize(hidden, dim=1))


class Network_BP_Q(nn.Module):
    """Quantized wrapper around a ``Network_BP``-shaped model.

    The ``fc1``/``fc2``/``fc3`` layers of ``model`` are wrapped in
    ``QuantLinear`` modules and its ``relu`` module is reused as-is.

    Args:
        model: Source model exposing ``relu``, ``fc1``, ``fc2`` and ``fc3``.
        weight_precision: Bit width passed to each ``QuantLinear``.
        bias_precision: ``bias_bit`` passed to each ``QuantLinear``.
        act_precision: Bit width passed to each ``QuantAct``.

    Raises:
        ValueError: If ``model`` is ``None``.
    """

    def __init__(self, model,
                 weight_precision=8,
                 bias_precision=32,
                 act_precision=16):
        super().__init__()
        if model is None:
            raise ValueError('Model cannot be None')

        self.quant_input = QuantAct(act_precision)
        self.quant_act1 = QuantAct(act_precision)
        self.quant_act2 = QuantAct(act_precision)

        # The activation module is shared with the source model.
        self.relu = model.relu

        # Wrap every linear layer of the source model in a QuantLinear.
        for layer_name in ('fc1', 'fc2', 'fc3'):
            float_layer = getattr(model, layer_name)
            quant_layer = QuantLinear(weight_precision, bias_bit=bias_precision)
            quant_layer.set_param(float_layer)
            setattr(self, layer_name, quant_layer)

    def forward(self, x):
        """Quantize the input and run the three quantized linear layers."""
        x, scale = self.quant_input(x)

        x = self.relu(self.fc1(x, scale))
        x, scale = self.quant_act1(x, scale)

        # NOTE(review): F.normalize rescales the activations after they were
        # quantized while ``scale`` is passed on unchanged -- confirm this is
        # what QuantLinear expects.
        x = self.relu(self.fc2(F.normalize(x, dim=1), scale))
        x, scale = self.quant_act2(x, scale)

        return self.fc3(F.normalize(x, dim=1), scale)


def mlp_bp(dims):
    """Build a ``Network_BP`` for the given layer sizes with biases enabled."""
    network = Network_BP(dims, bias=True)
    return network


def mlp_bp_q(model, dims,
             weight_precision=8,
             bias_precision=32,
             act_precision=16):
    """Build a ``Network_BP_Q`` from ``model``.

    When ``model`` is ``None`` a fresh ``mlp_bp(dims)`` is created and
    quantized instead; otherwise ``dims`` is not used.
    """
    float_model = mlp_bp(dims) if model is None else model
    return Network_BP_Q(float_model, weight_precision, bias_precision, act_precision)

#RESNET8
class RESNET(nn.Module):
    """Float residual CNN: input conv layer, three residual stacks, linear head.

    Args:
        pool: When true, a 2x2 max-pool (``mp1``) follows the input layer.
        in_channels: Channel count of the input images.
        out_channels: Width of the input layer and first stack; the second
            and third stacks use 2x and 4x this width.
        no_classes: Number of outputs of the final linear layer.

    The classifier expects ``4 * 4 * out_channels`` flattened features after
    ``AvgPool2d(4)``.
    """

    def __init__(self, pool, in_channels, out_channels, no_classes):
        super().__init__()

        self.no_classes = no_classes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pool = pool

        base = self.out_channels
        double = 2 * base
        quad = 4 * base
        conv3x3 = dict(kernel_size=3, padding=1)

        # Input layer
        self.conv1 = nn.Conv2d(self.in_channels, base, stride=1, **conv3x3)
        self.bn1 = nn.BatchNorm2d(base)
        self.relu = nn.ReLU(True)
        if self.pool:
            self.mp1 = nn.MaxPool2d(2)

        # First stack (identity shortcut)
        self.conv2 = nn.Conv2d(base, base, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(base)
        self.conv3 = nn.Conv2d(base, base, stride=1, **conv3x3)
        self.bn3 = nn.BatchNorm2d(base)

        # Second stack (conv6 is the strided shortcut)
        self.conv4 = nn.Conv2d(base, double, stride=2, **conv3x3)
        self.bn4 = nn.BatchNorm2d(double)
        self.conv5 = nn.Conv2d(double, double, stride=1, **conv3x3)
        self.bn5 = nn.BatchNorm2d(double)
        self.conv6 = nn.Conv2d(base, double, stride=2, **conv3x3)

        # Third stack (conv9 is the strided shortcut)
        self.conv7 = nn.Conv2d(double, quad, stride=2, **conv3x3)
        self.bn7 = nn.BatchNorm2d(quad)
        self.conv8 = nn.Conv2d(quad, quad, stride=1, **conv3x3)
        self.bn8 = nn.BatchNorm2d(quad)
        self.conv9 = nn.Conv2d(double, quad, stride=2, **conv3x3)

        # Classification head
        self.ap1 = nn.AvgPool2d(4)
        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(4 * 4 * base, self.no_classes)

    def forward(self, x):
        """Return the class scores produced by ``fc1``."""
        # Input layer
        x = self.relu(self.bn1(self.conv1(x)))
        if self.pool:
            x = self.mp1(x)

        # First stack: the shortcut is the unmodified input of the stack.
        branch = self.relu(self.bn2(self.conv2(x)))
        branch = self.bn3(self.conv3(branch))
        x = self.relu(branch + x)

        # Second stack: conv6 matches the shortcut to the strided branch.
        branch = self.relu(self.bn4(self.conv4(x)))
        branch = self.bn5(self.conv5(branch))
        shortcut = self.conv6(x)
        x = self.relu(branch + shortcut)

        # Third stack: conv9 matches the shortcut to the strided branch.
        branch = self.relu(self.bn7(self.conv7(x)))
        branch = self.bn8(self.conv8(branch))
        shortcut = self.conv9(x)
        x = self.relu(branch + shortcut)

        # Classification head
        return self.fc1(self.flat(self.ap1(x)))

class Q_ResNet8(nn.Module):
    """
        Quantized ResNet8-style network built from HAWQ quantization modules,
        following 'Deep Residual Learning for Image Recognition,'
        https://arxiv.org/abs/1512.03385.

        The layer layout mirrors ``RESNET`` in this file: an input conv+BN
        layer, three residual stacks (the second and third with a strided
        conv shortcut) and a linear classifier.

        Args:
            pool: When true, a 2x2 max-pool (``mp1``) follows the input layer.
            in_channels: Channel count of the input images.
            out_channels: Width of the input layer and first stack; later
                stacks use 2x and 4x this width.
            no_classes: Number of outputs of the final linear layer.
            weight_precision: ``weight_bit`` for every quantized conv/linear.
            bias_precision: ``bias_bit`` for every quantized conv/linear.
            act_precision: ``activation_bit`` for every ``QuantAct``.
    """
    def __init__(self, pool, in_channels, out_channels, no_classes,
                 weight_precision=8,
                 bias_precision=32,
                 act_precision=16):
        super().__init__()

        self.no_classes = no_classes # default class number for cifar10 is 10
        self.in_channels = in_channels # 3 channels
        self.out_channels = out_channels # this should be 64 for an official resnet model
        self.pool = pool  # boolean flag only; the pooling module itself is ``mp1``

        self.quant_input = QuantAct(activation_bit=act_precision)

        # Input layer
        self.quant_init_block_convbn = QuantBnConv2d(weight_bit=weight_precision, bias_bit=bias_precision, per_channel=True)
        self.quant_init_block_convbn.set_param(nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1, stride=1),
                                               nn.BatchNorm2d(self.out_channels)
                                               )
        if self.pool:
            self.mp1 = nn.MaxPool2d(2) #enable for official resnet model

        self.quant_act1 = QuantAct(activation_bit=act_precision)
        self.relu = nn.ReLU()

        # First stack
        self.quant_act2 = QuantAct(activation_bit=act_precision)
        self.convbn2 = QuantBnConv2d(weight_bit=weight_precision, bias_bit=bias_precision, per_channel=True)
        self.convbn2.set_param(nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1, stride=1),
                               nn.BatchNorm2d(self.out_channels)
                               )
        self.quant_act3 = QuantAct(activation_bit=act_precision)
        self.convbn3 = QuantBnConv2d(weight_bit=weight_precision, bias_bit=bias_precision, per_channel=True)
        self.convbn3.set_param(nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1, stride=1),
                               nn.BatchNorm2d(self.out_channels)
                               )
        self.quant_act4 = QuantAct(activation_bit=act_precision)

        # Second stack
        self.quant_act5 = QuantAct(activation_bit=act_precision)
        self.convbn4 = QuantBnConv2d(weight_bit=weight_precision, bias_bit=bias_precision, per_channel=True)
        self.convbn4.set_param(nn.Conv2d(self.out_channels, 2*self.out_channels, kernel_size=3, padding=1, stride=2),
                               nn.BatchNorm2d(2*self.out_channels)
                               )
        self.quant_act6 = QuantAct(activation_bit=act_precision)
        self.convbn5 = QuantBnConv2d(weight_bit=weight_precision, bias_bit=bias_precision, per_channel=True)
        self.convbn5.set_param(nn.Conv2d(2*self.out_channels, 2*self.out_channels, kernel_size=3, padding=1, stride=1),
                               nn.BatchNorm2d(2*self.out_channels)
                               )
        # Strided conv on the shortcut so it matches the main branch's shape.
        self.conv6 = QuantConv2d(weight_bit=weight_precision, bias_bit=bias_precision)
        self.conv6.set_param(nn.Conv2d(self.out_channels, 2*self.out_channels, kernel_size=3, padding=1, stride=2)
                             )
        self.quant_act7 = QuantAct(activation_bit=act_precision)

        # Third stack
        self.quant_act8 = QuantAct(activation_bit=act_precision)
        self.convbn7 = QuantBnConv2d(weight_bit=weight_precision, bias_bit=bias_precision, per_channel=True)
        self.convbn7.set_param(nn.Conv2d(2*self.out_channels, 4*self.out_channels, kernel_size=3, padding=1, stride=2),
                               nn.BatchNorm2d(4*self.out_channels)
                               )
        self.quant_act9 = QuantAct(activation_bit=act_precision)
        self.convbn8 = QuantBnConv2d(weight_bit=weight_precision, bias_bit=bias_precision, per_channel=True)
        self.convbn8.set_param(nn.Conv2d(4*self.out_channels, 4*self.out_channels, kernel_size=3, padding=1, stride=1),
                               nn.BatchNorm2d(4*self.out_channels)
                               )
        # Strided conv on the shortcut so it matches the main branch's shape.
        self.conv9 = QuantConv2d(weight_bit=weight_precision, bias_bit=bias_precision)
        self.conv9.set_param(nn.Conv2d(2*self.out_channels, 4*self.out_channels, kernel_size=3, padding=1, stride=2)
                             )
        self.quant_act10 = QuantAct(activation_bit=act_precision)

        # Final classification layer
        self.ap1 = QuantAveragePool2d(kernel_size=4, stride=4)
        self.flat = nn.Flatten()
        self.fc1 = QuantLinear(weight_bit=weight_precision, bias_bit=bias_precision)
        self.fc1.set_param(nn.Linear(4*4*self.out_channels, self.no_classes))

        self.quant_act_output = QuantAct(activation_bit=act_precision)

    def forward(self, x):
        """Run the quantized network and return the output of ``fc1``."""
        x, act_scaling_factor = self.quant_input(x)

        x, weight_scaling_factor = self.quant_init_block_convbn(x, act_scaling_factor)

        if self.pool:
            # ``self.pool`` is the boolean flag; the max-pool module is ``mp1``.
            # Calling the flag itself raised "'bool' object is not callable".
            x = self.mp1(x)

        x, act_scaling_factor = self.quant_act1(x, act_scaling_factor, weight_scaling_factor)
        x = self.relu(x)

        # First stack
        # NOTE(review): ``identity`` is taken after quant_act2 but is paired
        # with the scaling factor captured before it -- confirm against what
        # QuantAct expects for the identity branch.
        x_act_scaling_factor = act_scaling_factor.clone()
        x, act_scaling_factor = self.quant_act2(x, act_scaling_factor)
        identity = x
        x, weight_scaling_factor = self.convbn2(x, act_scaling_factor)
        x = self.relu(x)
        x, act_scaling_factor = self.quant_act3(x, act_scaling_factor, weight_scaling_factor)
        x, weight_scaling_factor = self.convbn3(x, act_scaling_factor)
        x = x + identity
        x, act_scaling_factor = self.quant_act4(x, act_scaling_factor, weight_scaling_factor, identity, x_act_scaling_factor, None)
        x = self.relu(x)

        # Second stack
        x, act_scaling_factor = self.quant_act5(x, act_scaling_factor)
        x_act_scaling_factor = act_scaling_factor.clone()
        identity, x_weight_scaling_factor = self.conv6(x, x_act_scaling_factor) # Adjust for change in dimension due to stride in identity
        x, weight_scaling_factor = self.convbn4(x, act_scaling_factor)
        x = self.relu(x)
        x, act_scaling_factor = self.quant_act6(x, act_scaling_factor, weight_scaling_factor)
        x, weight_scaling_factor = self.convbn5(x, act_scaling_factor)
        x = x + identity
        x, act_scaling_factor = self.quant_act7(x, act_scaling_factor, weight_scaling_factor, identity, x_act_scaling_factor, x_weight_scaling_factor)
        x = self.relu(x)

        # Third stack
        x, act_scaling_factor = self.quant_act8(x, act_scaling_factor)
        x_act_scaling_factor = act_scaling_factor.clone()
        identity, x_weight_scaling_factor = self.conv9(x, x_act_scaling_factor) # Adjust for change in dimension due to stride in identity
        x, weight_scaling_factor = self.convbn7(x, act_scaling_factor)
        x = self.relu(x)
        x, act_scaling_factor = self.quant_act9(x, act_scaling_factor, weight_scaling_factor)
        x, weight_scaling_factor = self.convbn8(x, act_scaling_factor)
        x = x + identity
        x, act_scaling_factor = self.quant_act10(x, act_scaling_factor, weight_scaling_factor, identity, x_act_scaling_factor, x_weight_scaling_factor)
        x = self.relu(x)

        # Fourth stack.
        # While the paper uses four stacks, for cifar10 that leads to a large increase in complexity for minor benefits

        # Final classification layer
        x = self.ap1(x, act_scaling_factor)

        x, act_scaling_factor = self.quant_act_output(x, act_scaling_factor)
        x = self.flat(x)
        x = self.fc1(x, act_scaling_factor)

        return x



def q_resnet8(weight_precision=8,
              bias_precision=32,
              act_precision=16):
    """Build a ``Q_ResNet8`` with no input pooling, 3 input channels,
    a base width of 16 and 10 classes, using the given bit widths."""
    return Q_ResNet8(
        pool=False,
        in_channels=3,
        out_channels=16,
        no_classes=10,
        weight_precision=weight_precision,
        bias_precision=bias_precision,
        act_precision=act_precision,
    )
    
def resnet8():
    """Build the float ``RESNET`` with no input pooling, 3 input channels,
    a base width of 16 and 10 classes."""
    return RESNET(
        pool=False,
        in_channels=3,
        out_channels=16,
        no_classes=10,
    )

class Q_ResNet18(nn.Module):
    """
        Quantized ResNet18 model from 'Deep Residual Learning for Image Recognition,' https://arxiv.org/abs/1512.03385.

        ``model`` must expose ``features.init_block.conv.{conv,bn}``,
        ``features.stage{1..4}.unit{1,2}`` and ``output``; each unit is
        wrapped in a ``Q_ResBlockBn``.
    """
    def __init__(self, model):
        super().__init__()
        features = model.features
        init_block = features.init_block

        self.quant_input = QuantAct()

        self.quant_init_block_convbn = QuantBnConv2d()
        self.quant_init_block_convbn.set_param(init_block.conv.conv, init_block.conv.bn)

        self.quant_act_int32 = QuantAct()

        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.act = nn.ReLU()

        # Number of residual units in each of the four stages.
        self.channel = [2, 2, 2, 2]

        for stage_idx, n_units in enumerate(self.channel, start=1):
            stage = getattr(features, "stage{}".format(stage_idx))
            for unit_idx in range(1, n_units + 1):
                unit = getattr(stage, "unit{}".format(unit_idx))
                quant_unit = Q_ResBlockBn()
                quant_unit.set_param(unit)
                # The dotted attribute name is looked up again verbatim in forward().
                setattr(self, f"stage{stage_idx}.unit{unit_idx}", quant_unit)

        self.final_pool = QuantAveragePool2d(kernel_size=7, stride=1)

        self.quant_act_output = QuantAct()

        self.quant_output = QuantLinear()
        self.quant_output.set_param(model.output)

    def forward(self, x):
        """Run the quantized network and return the output of ``quant_output``."""
        x, act_scaling_factor = self.quant_input(x)

        x, weight_scaling_factor = self.quant_init_block_convbn(x, act_scaling_factor)

        x = self.pool(x)
        x, act_scaling_factor = self.quant_act_int32(x, act_scaling_factor, weight_scaling_factor)
        x = self.act(x)

        # Residual units, visited stage by stage.
        for stage_idx, n_units in enumerate(self.channel, start=1):
            for unit_idx in range(1, n_units + 1):
                res_unit = getattr(self, f"stage{stage_idx}.unit{unit_idx}")
                x, act_scaling_factor = res_unit(x, act_scaling_factor)

        x = self.final_pool(x, act_scaling_factor)

        x, act_scaling_factor = self.quant_act_output(x, act_scaling_factor)
        flat = x.view(x.size(0), -1)
        return self.quant_output(flat, act_scaling_factor)

class Q_ResBlockBn(nn.Module):
    """
        Quantized ResNet block with residual path.

        The block is empty after construction; ``set_param`` must be called
        with a float unit before ``forward`` can be used.
    """
    def __init__(self):
        super().__init__()

    def set_param(self, unit):
        """Create the quantized layers from ``unit``.

        ``unit`` must expose ``resize_identity``, ``body.conv1`` and
        ``body.conv2`` (each with ``conv`` and ``bn``) and, when
        ``resize_identity`` is true, ``identity_conv``.
        """
        self.resize_identity = unit.resize_identity

        self.quant_act = QuantAct()

        first = unit.body.conv1
        self.quant_convbn1 = QuantBnConv2d()
        self.quant_convbn1.set_param(first.conv, first.bn)

        self.quant_act1 = QuantAct()

        second = unit.body.conv2
        self.quant_convbn2 = QuantBnConv2d()
        self.quant_convbn2.set_param(second.conv, second.bn)

        if self.resize_identity:
            shortcut = unit.identity_conv
            self.quant_identity_convbn = QuantBnConv2d()
            self.quant_identity_convbn.set_param(shortcut.conv, shortcut.bn)

        self.quant_act_int32 = QuantAct()

    def forward(self, x, scaling_factor_int32=None):
        """Return ``(output, activation scaling factor)`` for the block."""
        relu = nn.ReLU()

        if self.resize_identity:
            # Shortcut goes through its own quantized conv+BN.
            x, act_sf = self.quant_act(x, scaling_factor_int32)
            identity_act_sf = act_sf.clone()
            identity, identity_weight_sf = self.quant_identity_convbn(x, act_sf)
        else:
            # Shortcut is the block input, taken before requantization.
            identity = x
            x, act_sf = self.quant_act(x, scaling_factor_int32)
            identity_act_sf = scaling_factor_int32
            identity_weight_sf = None

        x, weight_sf = self.quant_convbn1(x, act_sf)
        x = relu(x)
        x, act_sf = self.quant_act1(x, act_sf, weight_sf)

        x, weight_sf = self.quant_convbn2(x, act_sf)

        x = x + identity
        x, act_sf = self.quant_act_int32(x, act_sf, weight_sf, identity, identity_act_sf, identity_weight_sf)

        return relu(x), act_sf

def q_resnet18(model):
    """Wrap ``model`` in a ``Q_ResNet18`` and return it."""
    return Q_ResNet18(model)
##################### SP Modules #####################

class MLP_Block_SP(nn.Module):
    """Linear layer applied to two inputs (``h`` and ``t``) with shared weights.

    Args:
        in_dim: Input feature size.
        out_dim: Output feature size.
        output_layer: When true, the ReLU is skipped.
        bias: Whether the linear layer has a bias term.
    """

    def __init__(self, in_dim, out_dim, output_layer=False, bias=False):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=bias)
        self.output_layer = output_layer
        self.relu = nn.ReLU(True)

    def forward(self, h, t, inference=False):
        """Return ``h`` only when ``inference`` is true, else ``(h, t)``.

        ``t`` is not touched in inference mode, so it may be ``None`` there.
        """
        h = self.fc(h)

        if inference:
            return h if self.output_layer else self.relu(h)

        t = self.fc(t)
        if self.output_layer:
            return h, t
        return self.relu(h), self.relu(t)


class Network_SP(nn.Module):
    """Stack of SP blocks: one ``MLP_Input_Block_SP`` followed by ``MLP_Block_SP`` layers.

    Args:
        dim_x: Sequence of layer sizes; consecutive entries give the
            input/output size of each block.
        dim_c: Passed to the input block as its second argument.
        bias: Forwarded to every block.
    """

    def __init__(self, dim_x, dim_c, bias=False):
        super().__init__()

        layers = [MLP_Input_Block_SP(dim_x[0], dim_c, dim_x[1], bias)]

        # The last hidden-to-output block is flagged as the output layer.
        last_idx = len(dim_x) - 3
        for idx, (d_in, d_out) in enumerate(zip(dim_x[1:-1], dim_x[2:])):
            layers.append(MLP_Block_SP(d_in, d_out, idx == last_idx, bias))

        # Held in a Sequential for registration/printing; forward() iterates
        # the children itself because each block takes extra arguments.
        self.blocks = nn.Sequential(*layers)
        self.n_blocks = len(layers)

    def forward(self, x):
        """Run every block in inference mode (second input ``None``)."""
        out = x
        for blk in self.blocks.children():
            out = blk(out, None, inference=True)
        return out

#RESNET
class RESNET_INPUT_SP(nn.Module):
    """SP input stage: passes the image through and projects ``c`` to image shape.

    Args:
        in_channels: Channel count of the image input; the projected ``t``
            is repeated to this many channels.
        out_channels: Stored on the instance; not used by this block.
        no_classes: Width of ``c``'s second dimension. ``c`` is tiled
            ``no_classes`` times to form the ``no_classes ** 2`` inputs of
            ``fc_t``.

    The projection target is fixed at 32x32 pixels.
    """

    def __init__(self, in_channels, out_channels, no_classes):
        super().__init__()

        self.no_classes = no_classes # default class number for cifar10 is 10
        self.in_channels = in_channels # 3 channels
        self.out_channels = out_channels # this should be 64 for an official resnet model
        self.fc_t = nn.Linear(self.no_classes*self.no_classes, 1*32*32)
        self.ln_t = nn.LayerNorm(1*32*32)
        # LeakyReLU's first positional argument is ``negative_slope``, not
        # ``inplace``: ``nn.LeakyReLU(True)`` gave a slope of 1.0, i.e. an
        # identity function. Pass ``inplace`` by keyword instead.
        self.relu = nn.LeakyReLU(inplace=True)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, c, inference=False):
        """Return ``x`` unchanged, plus the projected ``t`` when not in inference.

        Args:
            x: Image input; returned as-is.
            c: Tensor with ``no_classes`` columns (ignored when
                ``inference`` is true).
            inference: When true only ``x`` is returned.
        """
        h = x
        if inference:
            return h

        t = self.fc_t(c.repeat(1, self.no_classes))
        t = self.ln_t(t).view(c.shape[0:1]+(1,32,32))
        # One copy of the projected plane per image channel (previously a
        # hard-coded 3, which is the same for RGB input).
        t = t.repeat(1, self.in_channels, 1, 1)
        t = self.relu(t)
        t = self.dropout(t)
        return h, t

class RESNET_Input_Block(nn.Module):
    """Input conv layer applied to ``h`` and (outside inference) to ``t``.

    Both inputs go through the same conv -> BN -> LeakyReLU -> dropout
    (-> optional max-pool) pipeline with shared weights.

    Args:
        pool: When true, a 2x2 max-pool (``mp1``) ends the pipeline.
        in_channels: Channel count of the inputs.
        out_channels: Channel count produced by the conv.
        no_classes: Stored on the instance; not used by this block.

    After each call ``self.h`` holds ``avg_pool2d(h, 8)`` flattened per sample.
    """

    def __init__(self, pool, in_channels, out_channels, no_classes):
        super().__init__()

        self.no_classes = no_classes # default class number for cifar10 is 10
        self.in_channels = in_channels # 3 channels
        self.out_channels = out_channels # this should be 64 for an official resnet model
        self.pool = pool
        self.conv_h = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1, stride=1)
        self.bn1 = nn.BatchNorm2d(self.out_channels)
        # LeakyReLU's first positional argument is ``negative_slope``, not
        # ``inplace``: ``nn.LeakyReLU(True)`` gave a slope of 1.0, i.e. an
        # identity function. Pass ``inplace`` by keyword instead.
        self.relu = nn.LeakyReLU(inplace=True)
        self.dropout = nn.Dropout(0.1)
        if self.pool:
            self.mp1 = nn.MaxPool2d(2) #enable for official resnet model

    def _apply_layers(self, x):
        """Run the shared conv/BN/activation/dropout(/pool) pipeline on ``x``."""
        x = self.conv_h(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        if self.pool:
            x = self.mp1(x)
        return x

    def forward(self, h, t, inference=False):
        """Return ``h`` when ``inference`` is true, else ``(h, t)``.

        ``t`` is not touched in inference mode, so it may be ``None`` there.
        """
        h = self._apply_layers(h)

        self.h = torch.nn.functional.avg_pool2d(h, 8).flatten(1)

        if inference:
            return h
        return h, self._apply_layers(t)
        
class RESNET_Block(nn.Module):
    """Plain conv block applied to ``h`` and (outside inference) to ``t``.

    Both inputs go through the same conv -> BN -> LeakyReLU -> dropout
    pipeline with shared weights.

    Args:
        pool: When true an ``mp1`` max-pool module is created; note that
            ``forward`` does not apply it.
        in_channels: Channel count of the inputs.
        out_channels: Channel count produced by the conv.
        kernel_size, padding, stride: Conv hyper-parameters.
        normalize_input: When true, ``h`` (only) is L2-normalised along
            dim 1 before the conv.

    After each call ``self.h`` holds ``avg_pool2d(h, 8)`` flattened per sample.
    """

    def __init__(self, pool, in_channels, out_channels, kernel_size=3, padding=1, stride=1, normalize_input=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.pool = pool
        self.normalize_input = normalize_input

        self.conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=self.kernel_size, padding=self.padding, stride=self.stride)
        self.bn = nn.BatchNorm2d(self.out_channels)
        # LeakyReLU's first positional argument is ``negative_slope``, not
        # ``inplace``: ``nn.LeakyReLU(True)`` gave a slope of 1.0, i.e. an
        # identity function. Pass ``inplace`` by keyword instead.
        self.relu = nn.LeakyReLU(inplace=True)
        self.dropout = nn.Dropout(0.1)
        if self.pool:
            # NOTE(review): created but never applied in forward() -- confirm
            # whether pooling was meant to happen in this block.
            self.mp1 = nn.MaxPool2d(2) #enable for official resnet model

    def _apply_layers(self, x):
        """Run the shared conv/BN/activation/dropout pipeline on ``x``."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return self.dropout(x)

    def forward(self, h, t, inference=False):
        """Return ``h`` when ``inference`` is true, else ``(h, t)``.

        ``t`` is not touched in inference mode, so it may be ``None`` there.
        """
        if self.normalize_input:
            h = F.normalize(h, p=2, dim=1)
        h = self._apply_layers(h)
        self.h = torch.nn.functional.avg_pool2d(h, 8).flatten(1)

        if inference:
            return h
        return h, self._apply_layers(t)

class RESNET_Block_Res(nn.Module):
    """Conv block whose ``h`` path adds a residual before the activation.

    Args:
        in_channels: Channel count of ``h``/``t``.
        out_channels: Channel count produced by the conv.
        kernel_size, padding, stride: Conv hyper-parameters.
        adj: When true, ``res`` is passed through ``conv_adj_res`` before
            being added, so its shape can be matched to the conv output.
        adj_in_channels: Channel count of ``res`` (used only when ``adj``).
        adj_stride: Stride of ``conv_adj_res`` (used only when ``adj``).
        normalize_input: When true, ``h`` and ``res`` are L2-normalised
            along dim 1 first.

    After each call ``self.h`` holds ``avg_pool2d(h, 8)`` flattened per sample.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1, adj=False, adj_in_channels=16, adj_stride=1, normalize_input=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.adjust = adj
        self.normalize_input = normalize_input

        self.conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=self.kernel_size, padding=self.padding, stride=self.stride)
        self.bn = nn.BatchNorm2d(self.out_channels)
        # LeakyReLU's first positional argument is ``negative_slope``, not
        # ``inplace``: ``nn.LeakyReLU(True)`` gave a slope of 1.0, i.e. an
        # identity function. Pass ``inplace`` by keyword instead.
        self.relu = nn.LeakyReLU(inplace=True)
        self.dropout = nn.Dropout(0.1)
        if self.adjust:
            self.adj_stride = adj_stride
            self.adj_in_channels = adj_in_channels
            self.conv_adj_res = nn.Conv2d(self.adj_in_channels, self.out_channels, kernel_size=self.kernel_size, padding=self.padding, stride=self.adj_stride)

    def forward(self, h, res, t, inference=False):
        """Return ``h`` when ``inference`` is true, else ``(h, t)``.

        The ``t`` path goes through conv/BN/activation/dropout only; no
        residual is added to it. ``t`` is not touched in inference mode.
        """
        if self.normalize_input:
            h = F.normalize(h, p=2, dim=1)
            res = F.normalize(res, p=2, dim=1)

        h = self.conv(h)
        h = self.bn(h)
        if self.adjust:
            res = self.conv_adj_res(res)
        h = h + res
        h = self.relu(h)
        h = self.dropout(h)
        self.h = torch.nn.functional.avg_pool2d(h, 8).flatten(1)

        if inference:
            return h

        t = self.conv(t)
        t = self.bn(t)
        t = self.relu(t)
        t = self.dropout(t)
        return h, t

class RESNET_Output_Block(nn.Module):
    """Classification head: average pool, flatten, linear.

    Args:
        in_channels: Input feature size of the linear layer (i.e. the
            flattened size after pooling).
        out_channels: Output size of the linear layer.
        ap_dim: Kernel size of the average pool.
        normalize_input: When true, ``x`` (only) is L2-normalised along
            dim 1 before pooling.

    After each call ``self.h`` holds the linear output for ``x``.
    """

    def __init__(self, in_channels, out_channels, ap_dim, normalize_input=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ap_dim = ap_dim
        self.normalize_input = normalize_input

        self.ap = nn.AvgPool2d(self.ap_dim)
        self.flat = nn.Flatten()
        self.fc = nn.Linear(self.in_channels, self.out_channels)

    def forward(self, x, c, inference=False):
        """Return the head output for ``x``; outside inference also for ``c``.

        ``c`` is not touched in inference mode, so it may be ``None`` there.
        """
        features = F.normalize(x, p=2, dim=1) if self.normalize_input else x
        h = self.fc(self.flat(self.ap(features)))
        self.h = h

        if inference:
            return h

        t = self.fc(self.flat(self.ap(c)))
        return h, t
        
class RESNET8(nn.Module):
    """ResNet-style network assembled from the project's RESNET_* blocks.

    The layout is: [optional RESNET_INPUT_SP] -> RESNET_Input_Block ->
    three (RESNET_Block, RESNET_Block_Res) stacks -> RESNET_Output_Block.

    Args:
        pool: Forwarded unchanged to ``RESNET_Input_Block`` and ``RESNET_Block``.
        in_channels: Channel count handed to the input block(s).
        out_channels: Base width; the stacks use 1x, 2x and 4x this value.
        no_classes: Output size of the final linear layer.
        nettype: One of ``'sp'``, ``'bp'`` or ``'ff'``.
            ``'sp'`` prepends a ``RESNET_INPUT_SP`` block; ``'ff'`` passes
            ``normalize_input=True`` to part of the blocks and makes
            ``forward`` return per-block activations.

    Raises:
        ValueError: If ``nettype`` is not one of the supported values.
            (Previously an unknown value produced an empty network that only
            failed later, with an IndexError inside ``forward``.)
    """

    _NETTYPES = ('sp', 'bp', 'ff')

    def __init__(self, pool, in_channels, out_channels, no_classes, nettype='bp'):
        super().__init__()
        if nettype not in self._NETTYPES:
            raise ValueError(
                "Unknown nettype %r; expected one of %s" % (nettype, ", ".join(self._NETTYPES))
            )

        self.pool = pool
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.no_classes = no_classes
        self.nettype = nettype

        width = self.out_channels
        # Only the 'ff' variant asks blocks to normalise their input. For the
        # other variants the keyword is omitted entirely so each block keeps
        # its own default.
        norm_kw = {'normalize_input': True} if nettype == 'ff' else {}

        blocks = []
        if nettype == 'sp':
            # SP Layer
            blocks.append(RESNET_INPUT_SP(in_channels=self.in_channels, out_channels=width, no_classes=self.no_classes))
        # Input Layer
        blocks.append(RESNET_Input_Block(pool=self.pool, in_channels=self.in_channels, out_channels=width, no_classes=self.no_classes))
        # First stack
        blocks.append(RESNET_Block(self.pool, width, width, kernel_size=3, padding=1, stride=1, **norm_kw))
        blocks.append(RESNET_Block_Res(width, width, kernel_size=3, padding=1, stride=1, adj=False, **norm_kw))
        # Second stack
        # NOTE(review): in 'ff' mode the stride-2 RESNET_Block of the second and
        # third stacks is built WITHOUT normalize_input, unlike the first stack.
        # Preserved as found -- confirm whether that is intentional.
        blocks.append(RESNET_Block(self.pool, width, 2 * width, kernel_size=3, padding=1, stride=2))
        blocks.append(RESNET_Block_Res(2 * width, 2 * width, kernel_size=3, padding=1, stride=1, adj=True, adj_in_channels=width, adj_stride=2, **norm_kw))
        # Third stack
        blocks.append(RESNET_Block(self.pool, 2 * width, 4 * width, kernel_size=3, padding=1, stride=2))
        blocks.append(RESNET_Block_Res(4 * width, 4 * width, kernel_size=3, padding=1, stride=1, adj=True, adj_in_channels=2 * width, adj_stride=2, **norm_kw))
        # Output Layer
        blocks.append(RESNET_Output_Block(4 * 4 * width, self.no_classes, ap_dim=4, **norm_kw))

        # nn.Sequential is used as an indexable container (and for printing);
        # forward() calls the blocks one by one with extra arguments.
        self.blocks = nn.Sequential(*blocks)
        self.n_blocks = len(blocks)

    def forward(self, x):
        """Run all blocks in inference mode.

        Returns:
            For ``nettype == 'ff'``: a list with the ``h`` attribute of every
            block except the last one. Otherwise: the output of the last block.
        """
        first = 0
        if self.nettype == 'sp':
            # Extra SP input block sits in front of the regular input block.
            x = self.blocks[0](x, None, inference=True)
            first = 1

        x = self.blocks[first](x, None, inference=True)

        # Three (plain block, residual block) pairs; the residual block gets
        # the plain block's output and the pair's input as shortcut.
        for stack in range(3):
            plain = self.blocks[first + 1 + 2 * stack]
            residual = self.blocks[first + 2 + 2 * stack]
            h = plain(x, None, inference=True)
            x = residual(h, x, None, inference=True)

        x = self.blocks[first + 7](x, None, inference=True)

        if self.nettype == 'ff':
            # Activations cached by each block on its `h` attribute; the
            # output block is excluded.
            return [b.h for b in list(self.blocks.children())[:-1]]

        return x
    
def resnet8_v1(nettype='bp'):
    """Build a RESNET8 with 3 input channels, base width 16 and 10 classes.

    Args:
        nettype: Forwarded to ``RESNET8`` unchanged.
    """
    config = dict(pool=False, in_channels=3, out_channels=16, no_classes=10)
    return RESNET8(nettype=nettype, **config)


#MLP
class MLP_Input_Block_SP(nn.Module):
    """Input block with separate linear projections for ``x`` and ``c``.

    Args:
        in_dim_h: Input feature size of the ``x`` projection (``fc_h``).
        in_dim_t: Input feature size of the ``c`` projection (``fc_t``).
        out_dim: Output feature size shared by both projections.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(self, in_dim_h, in_dim_t, out_dim, bias=False):
        super().__init__()
        self.fc_h = nn.Linear(in_dim_h, out_dim, bias=bias)
        self.fc_t = nn.Linear(in_dim_t, out_dim, bias=bias)
        self.ln_t = nn.LayerNorm(out_dim)
        self.relu = nn.ReLU(True)
        self.output_layer = False

    def forward(self, x, c, inference=False):
        """Return ``relu(fc_h(x))`` and, unless ``inference``, the ``c`` branch too.

        The ``c`` branch is ``relu(ln_t(fc_t(c)))``; layer norm is applied
        only on that branch.
        """
        h = self.relu(self.fc_h(x))
        if inference:
            return h

        t = self.relu(self.ln_t(self.fc_t(c)))
        return h, t
        
class MLP_Block(nn.Module):
    """Linear + ReLU block with optional non-affine layer norm on the input.

    Args:
        in_dim: Input feature size.
        out_dim: Output feature size.
        bias: Whether the linear layer has a bias term.
        normalize_input: When true, ``h`` is passed through
            ``LayerNorm(in_dim, eps=1e-9, elementwise_affine=False)`` first.
    """

    def __init__(self, in_dim, out_dim, bias=False, normalize_input=False):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=bias)
        self.relu = nn.ReLU(True)
        self.norm = nn.LayerNorm(in_dim, eps=1e-9, elementwise_affine=False)
        self.normalize_input = normalize_input

    def forward(self, h, t, inference=False):
        """Return ``h`` (inference) or ``(h, t)``; caches ``h`` on ``self.h``.

        ``t`` is sent through the same linear layer and ReLU as ``h`` but is
        never normalised.
        """
        block_in = self.norm(h) if self.normalize_input else h
        h = self.relu(self.fc(block_in))
        self.h = h

        if inference:
            return h
        return h, self.relu(self.fc(t))
        
class MLP_Block_tweaked(nn.Module):
    """Linear + ReLU block that can L2-normalise its input along dim 1.

    Args:
        in_dim: Input feature size.
        out_dim: Output feature size.
        bias: Whether the linear layer has a bias term.
        normalize_input: When true, ``h`` goes through ``F.normalize(h, dim=1)``
            before the linear layer.
    """

    def __init__(self, in_dim, out_dim, bias=False, normalize_input=False):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=bias)
        self.relu = nn.ReLU(True)
        self.normalize_input = normalize_input

    def forward(self, h, t, inference=False):
        """Return ``h`` (inference) or ``(h, t)``; caches ``h`` on ``self.h``.

        ``t`` shares the linear layer and ReLU with ``h`` and is not normalised.
        """
        block_in = F.normalize(h, dim=1) if self.normalize_input else h
        h = self.relu(self.fc(block_in))
        self.h = h

        if inference:
            return h
        return h, self.relu(self.fc(t))

class MLP_Output_Block(nn.Module):
    """Final linear block; ``h`` is returned without an activation.

    Args:
        in_dim: Input feature size.
        out_dim: Output feature size.
        bias: Whether the linear layer has a bias term.
        normalize_input: When true, ``h`` is passed through
            ``LayerNorm(in_dim, eps=1e-9, elementwise_affine=False)`` first.
    """

    def __init__(self, in_dim, out_dim, bias=False, normalize_input=False):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=bias)
        self.relu = nn.ReLU(True)
        self.output_layer = True
        self.norm = nn.LayerNorm(in_dim, eps=1e-9, elementwise_affine=False)
        self.normalize_input = normalize_input

    def forward(self, h, t, inference=False):
        """Return ``h`` (inference) or ``(h, t)``; caches ``h`` on ``self.h``."""
        block_in = self.norm(h) if self.normalize_input else h
        h = self.fc(block_in)  # raw linear output, no ReLU
        self.h = h

        if inference:
            return h

        # NOTE(review): `t` gets a ReLU here while `h` does not; the quantized
        # counterpart in this file applies no ReLU to either. Behaviour is
        # preserved as found -- confirm which one is intended.
        t = self.relu(self.fc(t))
        return h, t

class MLP_Net(nn.Module):
    """Small MLP built from the MLP_* blocks, in one of three variants.

    Args:
        nettype: ``'sp'`` (MLP_Input_Block_SP + 2 MLP_Block + MLP_Output_Block),
            ``'ff'`` (3 MLP_Block, the last two with input normalisation) or
            ``'bp'`` (2 MLP_Block + MLP_Output_Block).
        bias: Accepted for interface compatibility but currently unused; all
            layers are built with ``bias=True``.

    Raises:
        ValueError: If ``nettype`` is not one of the supported values.
            (Previously an unknown value produced an empty network that only
            failed later, with an IndexError inside ``forward``.)
    """

    def __init__(self, nettype='bp', bias=False):
        super().__init__()
        if nettype not in ('sp', 'ff', 'bp'):
            raise ValueError(
                "Unknown nettype %r; expected one of sp, ff, bp" % (nettype,)
            )

        self.nettype = nettype

        # NOTE(review): `bias` is never read -- every layer below hard-codes
        # bias=True. Left unchanged to avoid altering existing behaviour.
        if nettype == 'sp':
            blocks = [
                MLP_Input_Block_SP(in_dim_h=120, in_dim_t=8, out_dim=8, bias=True),
                MLP_Block(in_dim=8, out_dim=8, bias=True),
                MLP_Block(in_dim=8, out_dim=8, bias=True),
                MLP_Output_Block(in_dim=8, out_dim=8, bias=True),
            ]
        elif nettype == 'ff':
            # The 'ff' variant takes a 128-wide input (vs 120 for the others).
            blocks = [
                MLP_Block(in_dim=128, out_dim=8, bias=True, normalize_input=False),
                MLP_Block(in_dim=8, out_dim=8, bias=True, normalize_input=True),
                MLP_Block(in_dim=8, out_dim=8, bias=True, normalize_input=True),
            ]
        else:  # 'bp'
            blocks = [
                MLP_Block(in_dim=120, out_dim=8, bias=True),
                MLP_Block(in_dim=8, out_dim=8, bias=True),
                MLP_Output_Block(in_dim=8, out_dim=8, bias=True),
            ]

        # nn.Sequential is used as an indexable container (and for printing);
        # forward() calls each block with extra arguments.
        self.blocks = nn.Sequential(*blocks)
        self.n_blocks = len(blocks)

    def forward(self, x):
        """Run every block in inference mode.

        Returns:
            For ``nettype == 'ff'``: the blocks' cached ``h`` tensors, each
            reshaped to ``(h.shape[0], -1)`` and stacked along dim 1.
            Otherwise: the output of the last block.
        """
        # The constructor guarantees 4 blocks for 'sp' and 3 otherwise, so
        # iterating over all of them matches the former explicit indexing.
        for block in self.blocks:
            x = block(x, None, inference=True)

        if self.nettype == 'ff':
            acts = [b.h.view(b.h.shape[0], -1) for b in self.blocks.children()]
            return torch.stack(acts, dim=1)

        return x

# QUANTIZED MLP
class MLP_Input_Block_SP_Q(nn.Module):
    """Quantized counterpart of ``MLP_Input_Block_SP``.

    Uses one ``QuantAct`` and two ``QuantLinear`` modules (imported from the
    HAWQ quantization utilities). ``set_param`` must be called before
    ``forward``: it is what provides ``self.relu`` and the layer weights.

    Args:
        weight_precision: Passed to ``QuantLinear`` as ``weight_bit``.
        bias_precision: Passed to ``QuantLinear`` as ``bias_bit``.
        act_precision: Passed to ``QuantAct`` as ``activation_bit``.
    """

    def __init__(self,
                 weight_precision=8,
                 bias_precision=32,
                 act_precision=16):
        super().__init__()

        linear_bits = dict(weight_bit=weight_precision, bias_bit=bias_precision)
        self.quant = QuantAct(activation_bit=act_precision)
        self.fc_h = QuantLinear(**linear_bits)
        self.fc_t = QuantLinear(**linear_bits)

    def set_param(self, layer):
        """Take linear layers, ReLU and ``output_layer`` flag from ``layer``.

        The float block's ``ln_t`` layer norm is not carried over.
        """
        self.fc_h.set_param(layer.fc_h)
        self.fc_t.set_param(layer.fc_t)
        self.relu = layer.relu
        self.output_layer = layer.output_layer

    def forward(self, x, c, inference=False):
        """Return ``h`` (inference) or ``(h, t)``.

        The same ``self.quant`` module is used for every quantization step;
        each call is unpacked as a (tensor, scaling factor) pair.
        """
        x_q, x_scale = self.quant(x)
        h, _ = self.quant(self.relu(self.fc_h(x_q, x_scale)), x_scale)

        if inference:
            return h

        c_q, c_scale = self.quant(c)
        # No layer norm on this branch, unlike the float block.
        t, _ = self.quant(self.relu(self.fc_t(c_q, c_scale)), c_scale)
        return h, t

class MLP_Block_Q(nn.Module):
    """Quantized counterpart of ``MLP_Block``.

    ``set_param`` must be called before ``forward``: it provides
    ``self.relu``, ``self.normalize_input`` and the linear layer's weights.

    Args:
        weight_precision: Passed to ``QuantLinear`` as ``weight_bit``.
        bias_precision: Passed to ``QuantLinear`` as ``bias_bit``.
        act_precision: Passed to ``QuantAct`` as ``activation_bit``.
    """

    def __init__(self,
                 weight_precision=8,
                 bias_precision=32,
                 act_precision=16):
        super().__init__()

        self.quant = QuantAct(activation_bit=act_precision)
        self.fc = QuantLinear(weight_bit=weight_precision, bias_bit=bias_precision)

    def set_param(self, layer):
        """Take linear layer, ReLU and ``normalize_input`` flag from ``layer``."""
        self.fc.set_param(layer.fc)
        self.relu = layer.relu
        self.normalize_input = layer.normalize_input

    def forward(self, h, t, inference=False):
        """Return ``h`` (inference) or ``(h, t)``; caches ``h`` on ``self.h``."""
        h, h_scale = self.quant(h)

        if self.normalize_input:
            # NOTE(review): the float MLP_Block uses a non-affine LayerNorm
            # here, while this block L2-normalises after quantization and
            # keeps the pre-normalisation scaling factor -- confirm intent.
            h = F.normalize(h, dim=1)

        h, _ = self.quant(self.relu(self.fc(h, h_scale)), h_scale)
        self.h = h

        if inference:
            return h

        t, t_scale = self.quant(t)
        t, _ = self.quant(self.relu(self.fc(t, t_scale)), t_scale)
        return h, t

class MLP_Output_Block_Q(nn.Module):
    """Quantized counterpart of ``MLP_Output_Block`` (no ReLU on either path).

    ``set_param`` must be called before ``forward``: it provides
    ``self.normalize_input`` and the linear layer's weights.

    Args:
        weight_precision: Passed to ``QuantLinear`` as ``weight_bit``.
        bias_precision: Passed to ``QuantLinear`` as ``bias_bit``.
        act_precision: Passed to ``QuantAct`` as ``activation_bit``.
    """

    def __init__(self,
                 weight_precision=8,
                 bias_precision=32,
                 act_precision=16):
        super().__init__()
        self.quant = QuantAct(activation_bit=act_precision)
        self.fc = QuantLinear(weight_bit=weight_precision, bias_bit=bias_precision)

    def set_param(self, layer):
        """Take linear layer, ReLU and ``normalize_input`` flag from ``layer``.

        ``self.relu`` is stored but not used by ``forward``.
        """
        self.fc.set_param(layer.fc)
        self.relu = layer.relu
        self.normalize_input = layer.normalize_input

    def forward(self, h, t, inference=False):
        """Return ``h`` (inference) or ``(h, t)``; caches ``h`` on ``self.h``."""
        h, h_scale = self.quant(h)

        if self.normalize_input:
            h = F.normalize(h, dim=1)

        # Linear output is re-quantized directly, without an activation.
        h, _ = self.quant(self.fc(h, h_scale), h_scale)
        self.h = h

        if inference:
            return h

        t, t_scale = self.quant(t)
        t, _ = self.quant(self.fc(t, t_scale), t_scale)
        return h, t

class MLP_Net_Q(nn.Module):
    """Quantized wrapper that mirrors an existing ``MLP_Net``-like model.

    Each block of ``model.blocks`` is mapped, by class name, to its quantized
    counterpart, which then copies its parameters through ``set_param``.

    Args:
        model: Source model exposing ``nettype`` and an iterable ``blocks``.
        weight_precision: Weight bit-width handed to every quantized block.
        bias_precision: Bias bit-width handed to every quantized block.
        act_precision: Activation bit-width handed to every quantized block.

    Raises:
        ValueError: If ``model`` is ``None``.
        ModuleNotFoundError: If a block of ``model`` has no quantized
            counterpart. (The exception used to be constructed but never
            raised, so such blocks were silently dropped.)
    """

    def __init__(self, model,
                 weight_precision=8,
                 bias_precision=32,
                 act_precision=16):

        super().__init__()
        if model is None:
            raise ValueError('Model cannot be None')
        self.nettype = model.nettype

        precisions = (weight_precision, bias_precision, act_precision)
        blocks = []
        for layer in model.blocks:
            # Match on the exact class name, as before (subclasses do not match).
            kind = type(layer).__name__
            if kind == 'MLP_Input_Block_SP':
                q_layer = MLP_Input_Block_SP_Q(*precisions)
            elif kind == 'MLP_Block':
                q_layer = MLP_Block_Q(*precisions)
            elif kind == 'MLP_Output_Block':
                q_layer = MLP_Output_Block_Q(*precisions)
            else:
                raise ModuleNotFoundError(
                    f"No quantized counterpart for layer type '{kind}'"
                )
            q_layer.set_param(layer)
            blocks.append(q_layer)

        self.blocks = nn.Sequential(*blocks)
        self.n_blocks = len(blocks)

    def forward(self, x):
        """Run the blocks in inference mode.

        Returns:
            For ``nettype == 'ff'``: the blocks' cached ``h`` tensors, each
            reshaped to ``(h.shape[0], -1)`` and stacked along dim 1.
            Otherwise: the output of the last executed block.
        """
        # 'sp' networks run four blocks, every other variant runs three.
        depth = 4 if self.nettype == 'sp' else 3
        for idx in range(depth):
            x = self.blocks[idx](x, None, inference=True)

        if self.nettype == 'ff':
            acts = [b.h.view(b.h.shape[0], -1) for b in self.blocks.children()]
            return torch.stack(acts, dim=1)

        return x
    
def mlp_v1(nettype='bp'):
    """Build an ``MLP_Net`` of the given ``nettype`` with ``bias=True``."""
    net = MLP_Net(nettype=nettype, bias=True)
    return net

def mlp_v1_q(model, nettype,
             weight_precision=8,
             bias_precision=32,
             act_precision=16):
    """Wrap ``model`` in an ``MLP_Net_Q`` with the given bit-widths.

    Args:
        model: Float model to quantize; when ``None`` a fresh ``mlp_v1(nettype)``
            is built instead.
        nettype: Only used to build the fallback model when ``model`` is ``None``.
        weight_precision: Forwarded to ``MLP_Net_Q``.
        bias_precision: Forwarded to ``MLP_Net_Q``.
        act_precision: Forwarded to ``MLP_Net_Q``.
    """
    float_model = mlp_v1(nettype) if model is None else model
    return MLP_Net_Q(float_model, weight_precision, bias_precision, act_precision)