import torch
import torch.nn as nn
import os
import pdb
from torch.nn import functional as F
import copy
from utils.conf import get_device

def APDConv2d(x, key_weights, task_id, stride, padding, module_key, info, l1_hyp, device):
    """Apply a 2-D convolution whose kernel is assembled from ``key_weights``.

    The kernel is ``msk * tsh + l1_pruning(tad, l1_hyp)`` where ``msk`` scales
    ``tsh`` along its first dimension. The names suggest shared weights, a
    per-task mask and a task-adaptive term -- presumably the APD
    decomposition; confirm against the training code.

    Args:
        x: input tensor handed to ``F.conv2d``.
        key_weights: dict cache that ``get_key_weights`` reads and fills.
        task_id: integer used to select the ``msk_t<id>`` / ``tad_t<id>`` entries.
        stride: stride forwarded to ``F.conv2d``.
        padding: padding forwarded to ``F.conv2d``.
        module_key: first part of the weight key.
        info: second part of the weight key (joined with a dot).
        l1_hyp: threshold forwarded to ``l1_pruning``.
        device: device the three weight tensors are moved to.

    Returns:
        The result of ``F.conv2d`` with the assembled kernel.
    """
    weight_key = module_key + '.' + info
    # NOTE(review): ``r18_dims`` is not defined or imported in the visible
    # file; it has to exist at module scope before this function can run.
    kernel_shape = r18_dims[weight_key]

    # Order matters: the ``tsh`` entry must exist before ``tad`` is requested,
    # because a missing ``tad`` is created as a copy of ``tsh``.
    shared = get_key_weights(key_weights, weight_key, 'tsh', kernel_shape).to(device)
    mask = get_key_weights(
        key_weights, weight_key, 'msk_t%d' % task_id, [kernel_shape[0]]
    ).to(device)
    adaptive = get_key_weights(
        key_weights, weight_key, 'tad_t%d' % task_id, kernel_shape, is_tad=True
    ).to(device)

    masked_shared = torch.einsum('i, i...->i...', mask, shared)
    kernel = masked_shared + l1_pruning(adaptive, l1_hyp)
    return F.conv2d(x, kernel, stride=stride, padding=padding)

def l1_pruning(input, l1_hyp = 0):
    """Zero every element whose magnitude does not exceed ``l1_hyp``.

    Args:
        input: tensor to threshold.
        l1_hyp: threshold; ``0`` disables pruning and returns ``input`` itself.

    Returns:
        ``input`` unchanged when ``l1_hyp == 0``; otherwise ``input`` multiplied
        by the boolean mask ``abs(input) > l1_hyp`` (elements equal to the
        threshold are zeroed).
    """
    if l1_hyp != 0:
        keep = torch.abs(input) > l1_hyp
        return input * keep
    return input

def get_key_weights(key_weights, w_info, w_type, shape, is_tad=False):
    """Return the tensor stored under ``'<w_info>.<w_type>'``, creating it on a miss.

    Args:
        key_weights: dict used as the cache; it is mutated when the entry is missing.
        w_info: first part of the cache key.
        w_type: second part of the cache key.
        shape: shape of a newly created tensor (ignored when ``is_tad`` is true).
        is_tad: when true, a missing entry is created as a deep copy of the
            ``'<w_info>.tsh'`` entry instead of being freshly initialised.

    Returns:
        The cached (or newly created and cached) tensor.

    Raises:
        KeyError: if ``is_tad`` is true and the ``'<w_info>.tsh'`` entry does
            not exist yet.
    """
    name = '{}.{}'.format(w_info, w_type)
    if name in key_weights:
        return key_weights[name]

    if is_tad:
        source_name = '{}.{}'.format(w_info, 'tsh')
        key_weights[name] = copy.deepcopy(key_weights[source_name])
    else:
        created = torch.ones(shape, requires_grad=True)
        key_weights[name] = created
        # Shapes with at most two dimensions stay at ones; anything with more
        # dimensions is re-initialised with Kaiming-normal.
        if len(shape) > 2:
            nn.init.kaiming_normal_(created.data, mode='fan_out', nonlinearity='relu')
    return key_weights[name]


class BasicBlock(nn.Module):
    """Two-convolution residual block whose kernels come from ``key_weights``.

    Both 3x3 convolutions are evaluated through ``APDConv2d`` using the keys
    ``'<module_key>.conv1.weights'`` and ``'<module_key>.conv2.weights'``; the
    block itself only owns the normalisation layers and the optional shortcut.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution (and of the shortcut projection).
        downsample: optional module applied to the input before the residual add.
        groups: must be 1.
        base_width: must be 64.
        dilation: must be 1.
        norm_layer: callable building a normalisation layer from a channel
            count; defaults to ``nn.BatchNorm2d``.
        is_shortcut: when true and the block changes resolution or width, a
            1x1 convolution + batch-norm projection is built for the identity path.
        l1_hyp: threshold forwarded to ``l1_pruning`` through ``APDConv2d``.
        module_key: prefix of this block's entries in ``key_weights``.
        key_weights: dict holding the decomposed weights; a private empty dict
            is used when omitted.
        device: device the weights are moved to at call time.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, is_shortcut=False, l1_hyp=0, module_key='',
                 key_weights=None, device='cpu'):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

        # State read by forward(); previously these were never stored, so the
        # first call failed with AttributeError.
        self.l1_hyp = l1_hyp
        self.module_key = module_key
        self.key_weights = {} if key_weights is None else key_weights
        self.device = device

        self.shortcut = None
        if is_shortcut and (stride != 1 or inplanes != self.expansion * planes):
            # NOTE(review): bn3 is not used by forward() (the projection below
            # carries its own batch-norm). It is kept only so the module's
            # parameter set stays what it was; its parameters never receive
            # gradients.
            self.bn3 = nn.BatchNorm2d(self.expansion * planes)
            self.shortcut = nn.Sequential(
                nn.Conv2d(inplanes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x, task_id, is_eval):
        """Run the block for one task.

        Args:
            x: input feature map.
            task_id: integer selecting the per-task ``msk``/``tad`` entries.
            is_eval: accepted for interface compatibility; currently unused
                because the module-level ``l1_pruning`` takes no evaluation flag.

        Returns:
            ``relu(bn2(conv2(relu(bn1(conv1(x))))) + identity)`` where identity
            is ``x`` passed through ``downsample`` or ``shortcut`` when present.
        """
        identity = x

        # The first convolution carries the block's stride so that its output
        # matches the strided downsample/shortcut projection.
        out = APDConv2d(x, self.key_weights, task_id, self.stride, 1,
                        self.module_key, 'conv1.weights', self.l1_hyp, self.device)
        out = F.relu(self.bn1(out))

        out = APDConv2d(out, self.key_weights, task_id, 1, 1,
                        self.module_key, 'conv2.weights', self.l1_hyp, self.device)
        out = self.bn2(out)

        # NOTE(review): earlier revisions sketched a key_weights-based
        # downsample ('<module_key>.downsample.weights'); it was never
        # finished, so the identity path uses ordinary modules.
        if self.downsample is not None:
            identity = self.downsample(x)
        elif self.shortcut is not None:
            identity = self.shortcut(x)

        out += identity
        return self.relu(out)

class ResNet(nn.Module):
    """ResNet backbone whose convolutions read their kernels from ``key_weights``.

    The stem convolution and every residual block fetch (and lazily create)
    their weights in the shared ``self.key_weights`` dict; only the
    normalisation layers, downsample projections and the final linear layer
    are regular submodules.

    NOTE(review): tensors in ``key_weights`` are plain tensors stored in a
    dict, so ``parameters()``, ``get_params()`` and ``get_grads()`` do not
    include them.

    Args:
        block: residual block class. It is built as
            ``block(inplanes, planes, stride, downsample, groups, base_width,
            dilation, norm_layer, l1_hyp=..., module_key=...)`` and called as
            ``blk(x, task_id, is_eval)``; it must expose ``expansion``.
        layers: number of blocks in each of the four stages.
        device: device the decomposed weights are moved to at call time.
        num_classes: output size of the final linear layer.
        tad_thr: pruning threshold, stored as ``self.l1_hyp`` and forwarded to
            ``l1_pruning`` via ``APDConv2d``.
        zero_init_residual: zero the last norm weight of every residual branch.
        groups: forwarded to each block.
        width_per_group: forwarded to each block as ``base_width``.
        replace_stride_with_dilation: three flags, one per strided stage,
            replacing the stride with dilation; ``None`` means all false.
        norm_layer: callable building a normalisation layer from a channel
            count; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 elements.
    """

    def __init__(self, block, layers, device, num_classes=100, tad_thr=0, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        self.n_classes = num_classes
        self.device = device
        self.key_weights = {}
        # The constructor argument is called tad_thr; the rest of the file
        # refers to the same value as l1_hyp.
        self.l1_hyp = tad_thr

        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # The stem convolution (3x3, stride 1, padding 1) has no module of its
        # own: embed() evaluates it through APDConv2d under key 'conv1.weights'.
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0], module_key='layer1')
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0],
                                       module_key='layer2')
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1],
                                       module_key='layer3')
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2],
                                       module_key='layer4')
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, self.n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            # Bottleneck is not defined in the visible file, so look it up
            # defensively instead of failing with NameError.
            bottleneck_cls = globals().get('Bottleneck')
            for m in self.modules():
                if bottleneck_cls is not None and isinstance(m, bottleneck_cls):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, module_key=''):
        """Build one stage as an ``nn.ModuleList`` of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and the downsample
        projection; the remaining blocks keep resolution and width.

        Args:
            block: residual block class.
            planes: base channel count of the stage.
            blocks: number of blocks in the stage.
            stride: stride of the first block.
            dilate: trade the stride for dilation (``self.dilation`` is
                multiplied by ``stride`` and the stride reset to 1).
            module_key: stage name; block ``i`` gets ``'<module_key>.<i>'``.
                Assumes the weight-shape table is keyed the same way
                (e.g. ``'layer1.0.conv1.weights'``) -- TODO confirm.

        Returns:
            ``nn.ModuleList`` with the stage's blocks, in order.
        """
        norm_layer = self._norm_layer
        out_channels = planes * block.expansion
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                norm_layer(out_channels),
            )

        stage = nn.ModuleList()
        stage.append(self._build_block(
            block, planes, stride, downsample, previous_dilation,
            '%s.%d' % (module_key, 0)))
        self.inplanes = out_channels
        for index in range(1, blocks):
            stage.append(self._build_block(
                block, planes, 1, None, self.dilation,
                '%s.%d' % (module_key, index)))
        return stage

    def _build_block(self, block, planes, stride, downsample, dilation, module_key):
        """Instantiate one block and attach the shared APD state to it."""
        blk = block(self.inplanes, planes, stride, downsample, self.groups,
                    self.base_width, dilation, self._norm_layer,
                    l1_hyp=self.l1_hyp, module_key=module_key)
        # The block's forward() reads these attributes; assigning them here
        # guarantees every block shares this network's weight dict and device.
        blk.key_weights = self.key_weights
        blk.device = self.device
        blk.l1_hyp = self.l1_hyp
        blk.module_key = module_key
        return blk

    @staticmethod
    def _run_layer(layer, x, task_id, is_eval):
        """Feed ``x`` through every block of ``layer`` in order."""
        for blk in layer:
            x = blk(x, task_id, is_eval)
        return x

    def embed(self, x, task_id=0):
        """Compute the pooled feature vector and the intermediate feature maps.

        Args:
            x: input batch.
            task_id: integer selecting the per-task weights; defaults to 0 so
                existing ``embed(x)`` calls keep working.

        Returns:
            Tuple ``(features, [x1, x2, x3, x4, x5])``: the flattened
            average-pooled output of the last stage, and the outputs of the
            stem and of the four stages.
        """
        # Blocks take an is_eval flag; derive it from the module's mode.
        is_eval = not self.training
        x1 = F.relu(self.bn1(APDConv2d(x, self.key_weights, task_id, 1, 1, 'conv1', 'weights', self.l1_hyp, self.device)))
        x2 = self._run_layer(self.layer1, x1, task_id, is_eval)
        x3 = self._run_layer(self.layer2, x2, task_id, is_eval)
        x4 = self._run_layer(self.layer3, x3, task_id, is_eval)
        x5 = self._run_layer(self.layer4, x4, task_id, is_eval)
        x = self.avgpool(x5)
        x = x.reshape(x.size(0), -1)
        return x, [x1, x2, x3, x4, x5]

    def forward(self, x, return_features=False, plot_metrics=False, task_id=0):
        """Run the network.

        Args:
            x: input batch.
            return_features: return the pooled feature vector instead of logits.
            plot_metrics: additionally return the list of intermediate
                feature maps (ignored when ``return_features`` is true).
            task_id: integer selecting the per-task weights.

        Returns:
            The feature vector, the logits, or ``(logits, features_list)``.
        """
        x, features_list = self.embed(x, task_id)
        if return_features:
            return x

        x = self.fc(x)
        if plot_metrics:
            return x, features_list
        return x

    def get_params(self):
        """Return all registered parameters flattened into one 1-D tensor."""
        return torch.cat([pp.view(-1) for pp in self.parameters()])

    def get_grads(self):
        """Return the gradients of all registered parameters as one 1-D tensor.

        Every parameter must already have a gradient; a parameter whose
        ``grad`` is ``None`` raises ``AttributeError``, which keeps the result
        aligned element-for-element with ``get_params()``.
        """
        return torch.cat([pp.grad.view(-1) for pp in self.parameters()])


def ResNet18(device='cpu', **kwargs):
    """Build a ``ResNet`` of ``BasicBlock`` with two blocks in each stage.

    Args:
        device: forwarded to ``ResNet`` as its ``device`` argument.
        **kwargs: forwarded to ``ResNet`` unchanged.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, device, **kwargs)

def ResNet34(device='cpu', **kwargs):
    """Build a ``ResNet`` of ``BasicBlock`` with 3, 4, 6 and 3 blocks per stage.

    Args:
        device: forwarded to ``ResNet`` as its ``device`` argument.
        **kwargs: forwarded to ``ResNet`` unchanged.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(BasicBlock, blocks_per_stage, device, **kwargs)

def ResNet50(device='cpu', **kwargs):
    """Build a ``ResNet`` of ``Bottleneck`` with 3, 4, 6 and 3 blocks per stage.

    Args:
        device: forwarded to ``ResNet`` as its ``device`` argument.
        **kwargs: forwarded to ``ResNet`` unchanged.

    Raises:
        NotImplementedError: if no ``Bottleneck`` class is defined in this
            module (the only definition in the file is commented out).
    """
    try:
        block = Bottleneck
    except NameError as err:
        raise NotImplementedError(
            'ResNet50 needs a Bottleneck block, which is not defined in this module'
        ) from err
    return ResNet(block, [3, 4, 6, 3], device, **kwargs)

def ResNet101(device='cpu', **kwargs):
    """Build a ``ResNet`` of ``Bottleneck`` with 3, 4, 23 and 3 blocks per stage.

    Args:
        device: forwarded to ``ResNet`` as its ``device`` argument.
        **kwargs: forwarded to ``ResNet`` unchanged.

    Raises:
        NotImplementedError: if no ``Bottleneck`` class is defined in this
            module (the only definition in the file is commented out).
    """
    try:
        block = Bottleneck
    except NameError as err:
        raise NotImplementedError(
            'ResNet101 needs a Bottleneck block, which is not defined in this module'
        ) from err
    return ResNet(block, [3, 4, 23, 3], device, **kwargs)

def ResNet152(device='cpu', **kwargs):
    """Build a ``ResNet`` of ``Bottleneck`` with 3, 8, 36 and 3 blocks per stage.

    Args:
        device: forwarded to ``ResNet`` as its ``device`` argument.
        **kwargs: forwarded to ``ResNet`` unchanged.

    Raises:
        NotImplementedError: if no ``Bottleneck`` class is defined in this
            module (the only definition in the file is commented out).
    """
    try:
        block = Bottleneck
    except NameError as err:
        raise NotImplementedError(
            'ResNet152 needs a Bottleneck block, which is not defined in this module'
        ) from err
    return ResNet(block, [3, 8, 36, 3], device, **kwargs)


def test():
    """Smoke test: push one random 3x32x32 image through ResNet18 and print the output size."""
    model = ResNet18()
    sample = torch.randn(1, 3, 32, 32)
    logits = model(sample)
    print(logits.size())



"""
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, module_key=''):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.module_key = module_key
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)

    #def forward(self, x, weights, keylists):
    def forward(self, inputs):
        module_key = self.module_key
        x = inputs[0]
        self.weights = inputs[1]
        identity = x
        out = FWeightStandardizedConv2d(x, self.weights[module_key+'conv1.weight'].to(DEVICE), stride=self.stride)
        out = F.group_norm(out, 32, self.weights[module_key+'bn.weight'].to(DEVICE), self.weights[module_key+'bn.bias'].to(DEVICE))
        out = self.relu(out)
        out = FWeightStandardizedConv2d(out, self.weights[module_key+'conv2.weight'].to(DEVICE))
        out = F.group_norm(out, 32, self.weights[module_key+'bn.weight'].to(DEVICE), self.weights[module_key+'bn.bias'].to(DEVICE))

        if self.downsample is not None:
            identity = FWeightStandardizedConv2d(x, self.weights[module_key+'downsample.0.weight'].to(DEVICE), stride=self.stride, padding=0)
            identity = F.group_norm(identity, 32, self.weights[module_key+'downsample.1.weight'].to(DEVICE), self.weights[module_key+'downsample.1.bias'].to(DEVICE))
        out += identity
        out = self.relu(out)

        return out
"""

"""
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
"""

"""
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, is_shortcut=False):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

        if is_shortcut:
            self.shortcut = nn.Sequential()
            if stride != 1 or inplanes != 1*planes:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(inplanes, 1*planes,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(1*planes)
                )

"""
