from faulthandler import is_enabled
from turtle import forward
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict

def ew(weight, t=2):
    """Exponentially re-weight a tensor by the magnitude of its elements.

    Every element ``w`` is multiplied by ``exp(t*|w|) / max(exp(t*|W|))``,
    a factor in (0, 1]. For ``t > 0`` the largest-magnitude element keeps its
    value and smaller-magnitude elements are shrunk towards zero.
    (Per-row / per-column max normalisation were earlier alternatives.)

    Args:
        weight: tensor to transform. ``None`` (e.g. a layer without bias) and
            empty tensors are returned unchanged.
        t: sharpness of the exponential.

    Returns:
        Tensor with the same shape as ``weight`` (or ``weight`` itself for
        ``None``/empty input).
    """
    if weight is None or weight.numel() == 0:
        return weight
    scaled = torch.abs(weight) * t
    # exp(a) / max(exp(a)) == exp(a - max(a)). Subtracting the max first keeps
    # every exponent <= 0, so large weights cannot overflow to inf (which would
    # turn the ratio into inf/inf = nan).
    return torch.exp(scaled - scaled.max()) * weight

# Module-wide default for the ``ew_enable`` flag of the EW layers below. Each
# layer copies it once in its __init__, so changing it later does not affect
# layers that were already constructed (use ``enable_ew`` on those instead).
ENABLE_EW: bool = False

class EWLinear(nn.Linear):
    """``nn.Linear`` whose parameters can be passed through :func:`ew`.

    When ``ew_enable`` is set, the forward pass uses ``ew(weight)`` and
    ``ew(bias)`` in place of the raw parameters without modifying them.
    :meth:`expize_weight` bakes the transform into the stored parameters.
    Works with ``bias=False`` (the bias is then simply absent).
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__(in_features, out_features, bias)
        # Initial state comes from the module-level switch at construction time.
        self.ew_enable = ENABLE_EW

    def enable_ew(self, enable=True):
        """Turn exponential re-weighting on (default) or off."""
        self.ew_enable = enable

    def forward(self, input):
        weight, bias = self.weight, self.bias
        if self.ew_enable:
            weight = ew(weight)
            # self.bias is None when constructed with bias=False.
            if bias is not None:
                bias = ew(bias)
        return F.linear(input, weight, bias)

    def expize_weight(self):
        """Overwrite the parameters in place with their ew-transformed values.

        Afterwards, running with ew disabled gives what the layer previously
        computed with ew enabled.
        """
        with torch.no_grad():
            self.weight.data = ew(self.weight)
            if self.bias is not None:
                self.bias.data = ew(self.bias)

class EWConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose weight can be passed through :func:`ew`.

    Only the weight is transformed. A bias, if present, is always used as-is,
    both in the forward pass and in :meth:`expize_weight`.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride = 1, padding = 0, dilation= 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode)
        # Initial state comes from the module-level switch at construction time.
        self.ew_enable = ENABLE_EW

    def enable_ew(self, enable=True):
        """Turn exponential re-weighting of the weight on (default) or off."""
        self.ew_enable = enable

    def forward(self, input):
        weight = ew(self.weight) if self.ew_enable else self.weight
        # PyTorch 1.8+ `_conv_forward` takes the bias explicitly; the old
        # two-argument call raises TypeError there. It still handles
        # non-'zeros' padding modes for us.
        return self._conv_forward(input, weight, self.bias)

    def expize_weight(self):
        """Overwrite the weight in place with ``ew(weight)`` (bias untouched)."""
        with torch.no_grad():
            self.weight.data = ew(self.weight)


class EWBatchNorm2d(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` whose affine parameters can be passed through :func:`ew`.

    ``forward`` mirrors the upstream ``_BatchNorm.forward`` but substitutes
    ``ew(weight)`` / ``ew(bias)`` when ``ew_enable`` is set. With
    ``affine=False`` there are no affine parameters, so ew is a no-op.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(EWBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        # Initial state comes from the module-level switch at construction time.
        self.ew_enable = ENABLE_EW

    def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum (when it is
        # available) only so that it gets updated in the ONNX graph when this
        # node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # The None check is kept from upstream (it tells the JIT it may skip this).
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # Mini-batch stats are used for normalization in training mode, and in
        # eval mode when the running-stat buffers are None.
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        weight, bias = self.weight, self.bias
        if self.ew_enable:
            # weight/bias are None when affine=False; ew(None) would fail.
            if weight is not None:
                weight = ew(weight)
            if bias is not None:
                bias = ew(bias)

        # Buffers are only updated when they are tracked and we are training,
        # so they are passed only then, or when they are used for normalization
        # (eval mode with non-None buffers).
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            weight, bias, bn_training, exponential_average_factor, self.eps)

    def enable_ew(self, enable=True):
        """Turn exponential re-weighting of the affine parameters on/off."""
        self.ew_enable = enable

    def expize_weight(self):
        """Overwrite the affine parameters in place with their ew values (no-op if affine=False)."""
        with torch.no_grad():
            if self.weight is not None:
                self.weight.data = ew(self.weight)
            if self.bias is not None:
                self.bias.data = ew(self.bias)

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 EWConv2d with padding 1.

    With stride 1 the spatial size is preserved; ``stride`` is forwarded as-is.
    """
    kernel = 3
    return EWConv2d(in_planes, out_planes, kernel, stride=stride, padding=1, bias=False)


class EWBasicBlock(nn.Module):
    """ResNet "basic" residual block: two 3x3 convolutions plus a shortcut.

    Convolutions are EWConv2d (via ``conv3x3``); normalisation uses plain
    ``nn.BatchNorm2d``. When the stride or channel count changes, the shortcut
    is a 1x1 strided EWConv2d followed by BN; otherwise it is an empty
    ``nn.Sequential`` that returns its input.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(EWBasicBlock, self).__init__()
        out_planes = self.expansion * planes
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            projection = OrderedDict()
            projection['conv'] = EWConv2d(in_planes, out_planes, kernel_size=1,
                                          stride=stride, bias=False)
            projection['bn'] = nn.BatchNorm2d(out_planes)
            self.shortcut = nn.Sequential(projection)

    def forward(self, x):
        y = self.conv1(x)
        y = F.relu(self.bn1(y))
        y = self.conv2(y)
        y = self.bn2(y)
        y += self.shortcut(x)
        return F.relu(y)


class EWPreActBlock(nn.Module):
    '''Pre-activation version of the EWBasicBlock.

    BN -> ReLU precede each 3x3 EWConv2d. When the stride or channel count
    changes, the shortcut is a strided 1x1 EWConv2d applied to the
    pre-activated input. Otherwise the shortcut is the identity on the raw
    input ``x``.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(EWPreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

        # Always present (empty when no projection is needed) so the attribute
        # exists uniformly across blocks.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(OrderedDict([
                ('conv', EWConv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False))
            ]))

    def forward(self, x):
        out = F.relu(self.bn1(x))
        # Only a projection shortcut consumes the pre-activated tensor. The
        # identity shortcut must carry x itself. Calling the empty Sequential
        # on `out` would silently add relu(bn1(x)) instead.
        shortcut = self.shortcut(out) if len(self.shortcut) > 0 else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out


class EWBottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand.

    Convolutions are EWConv2d; normalisation uses plain ``nn.BatchNorm2d``.
    The block outputs ``expansion * planes`` channels.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(EWBottleneck, self).__init__()
        width_out = self.expansion * planes
        self.conv1 = EWConv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = EWConv2d(planes, planes, kernel_size=3, stride=stride,
                              padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = EWConv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)

        # Project the input (strided 1x1 conv + BN) whenever the stride or the
        # channel count changes; otherwise an empty Sequential returns it as-is.
        if stride == 1 and in_planes == width_out:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(OrderedDict([
                ('conv', EWConv2d(in_planes, width_out, kernel_size=1,
                                  stride=stride, bias=False)),
                ('bn', nn.BatchNorm2d(width_out)),
            ]))

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.conv2(h)
        h = F.relu(self.bn2(h))
        h = self.conv3(h)
        h = self.bn3(h)
        h += self.shortcut(x)
        return F.relu(h)


class EWPreActBottleneck(nn.Module):
    '''Pre-activation version of the original EWBottleneck module.

    BN -> ReLU precede each of the 1x1 / 3x3 / 1x1 EWConv2d layers, and the
    block outputs ``expansion * planes`` channels. When the stride or channel
    count changes, the shortcut is a strided 1x1 EWConv2d applied to the
    pre-activated input. Otherwise it is the identity on the raw input ``x``.
    '''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(EWPreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = EWConv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = EWConv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = EWConv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

        # Always present (empty when no projection is needed) so the attribute
        # exists uniformly across blocks.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(OrderedDict([
                ('conv', EWConv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False))
            ]))

    def forward(self, x):
        out = F.relu(self.bn1(x))
        # Only a projection shortcut consumes the pre-activated tensor. The
        # identity shortcut must carry x itself. Calling the empty Sequential
        # on `out` would silently add relu(bn1(x)) instead.
        shortcut = self.shortcut(out) if len(self.shortcut) > 0 else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        out += shortcut
        return out



class EWResNet(nn.Module):
    """ResNet assembled from EW layers (EWConv2d stem, EWLinear classifier).

    Args:
        block: residual block class exposing an ``expansion`` class attribute
            (e.g. EWBasicBlock or EWBottleneck).
        num_blocks: four ints, the number of blocks in stages 1-4.
        num_classes: output size of the final EWLinear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(EWResNet, self).__init__()
        # Channel count fed to the next stage; updated by _make_layer.
        self.in_planes = 64

        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        # Stage widths double; stages 2-4 use stride 2 in their first block.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = EWLinear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack the blocks of one stage; only the first uses ``stride``.

        Mutates ``self.in_planes`` so the following stage sees the new width.
        """
        blocks = []
        for block_stride in [stride, *([1] * (num_blocks - 1))]:
            blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x, mask=None, lin=0, lout=6):
        """Run a contiguous slice of the network's stages on ``x``.

        Stage ``i`` for i in 0..5 (0 = conv1/bn1/relu stem, 1-4 = layer1-4,
        5 = 4x4 average pool + flatten) runs when ``lin < i + 1`` and
        ``lout > i - 1``. The classifier (stage 6) runs whenever ``lout > 5``.
        ``mask`` is accepted but not used.
        """
        stages = (
            lambda t: F.relu(self.bn1(self.conv1(t))),
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            lambda t: F.avg_pool2d(t, 4).view(t.size(0), -1),
        )
        out = x
        for i, stage in enumerate(stages):
            if lin < i + 1 and lout > i - 1:
                out = stage(out)
        if lout > 5:
            out = self.linear(out)
        return out

class EWCNN(nn.Module):
    """Small convolutional classifier built from EW layers.

    Four 3x3, stride-2, padding-1 EWConv2d layers (3->16->32->64->128
    channels) followed by an EWLinear on the flattened features. The linear
    layer expects 512 = 128 * 2 * 2 inputs, which matches a 32x32 input image.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        widths = (3, 16, 32, 64, 128)
        # Registers conv1..conv4 in order; spatial size 32 -> 16 -> 8 -> 4 -> 2.
        for idx in range(1, len(widths)):
            setattr(self, 'conv%d' % idx,
                    EWConv2d(widths[idx - 1], widths[idx], 3, 2, 1))
        self.linear = EWLinear(512, num_classes)

    def forward(self, x, mask=None):
        # NOTE(review): there is no nonlinearity between the convolutions, so
        # the whole network is a single affine map of x. Confirm this is
        # intended. ``mask`` is accepted but not used.
        features = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = conv(features)
        flat = features.view(features.size(0), -1)
        return self.linear(flat)

def EWResNet18(num_classes=10):
    """Build an EW ResNet-18 (EWBasicBlock, [2, 2, 2, 2])."""
    return EWResNet(EWBasicBlock, [2, 2, 2, 2], num_classes)


def EWResNet34(num_classes=10):
    """Build an EW ResNet-34 (EWBasicBlock, [3, 4, 6, 3])."""
    return EWResNet(EWBasicBlock, [3, 4, 6, 3], num_classes)


def EWResNet50(num_classes=10):
    """Build an EW ResNet-50 (EWBottleneck, [3, 4, 6, 3])."""
    return EWResNet(EWBottleneck, [3, 4, 6, 3], num_classes)


def EWResNet101(num_classes=10):
    """Build an EW ResNet-101 (EWBottleneck, [3, 4, 23, 3])."""
    return EWResNet(EWBottleneck, [3, 4, 23, 3], num_classes)


def EWResNet152(num_classes=10):
    """Build an EW ResNet-152 (EWBottleneck, [3, 8, 36, 3])."""
    return EWResNet(EWBottleneck, [3, 8, 36, 3], num_classes)



# test()
