import pdb
import torch
import torch.nn as nn
import os
import torch.nn.functional as F
#from ..wage_modules.wage_initializer import wage_init_

# https://raw.githubusercontent.com/huyvnphan/PyTorch_CIFAR10/master/cifar10_models/resnet.py
# Public API of this module.
# NOTE(review): resnet50 is exported, but no working Bottleneck block is
# defined in this module, so it cannot build a model.
__all__ = ['ResNet', 'resnet18', 'resnet50']
# Device that weight tensors looked up from the ``weights`` dict are moved to
# (via ``.to(DEVICE)``) before every conv / norm / linear call. Hard-coded to
# CUDA, so inputs must live on a CUDA device as well.
DEVICE='cuda'

def FConv2d(x, weight, bias=None, stride=1, padding=1, dilation=1, groups=1):
    """Apply a 2-D convolution with an explicitly supplied kernel.

    Thin wrapper over ``F.conv2d`` whose only difference is the default
    ``padding=1`` (so a 3x3 kernel with stride 1 keeps the spatial size).
    Returns the convolved tensor.
    """
    return F.conv2d(
        x,
        weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )

class BasicBlock(nn.Module):
    """Residual block of two 3x3 convs whose kernels are supplied at call time.

    The block registers no convolution parameters. ``forward`` receives an
    ``(x, weights)`` pair and reads its kernels from ``weights`` under
    ``module_key + 'conv1.weight'``, ``module_key + 'conv2.weight'`` and, when
    ``downsample`` is set, ``module_key + 'downsample.0.weight'``. Only the
    normalization layer ``self.bn`` (and a parameter-free ReLU) are submodules.

    Args:
        stride: stride of the first conv and of the shortcut projection.
        downsample: any non-``None`` value enables the projection shortcut;
            the value itself is never called.
        groups, base_width, dilation: accepted for signature compatibility;
            unused.
        norm_layer: normalization-layer constructor. ``nn.GroupNorm`` (or a
            subclass) is called as ``norm_layer(32, planes)``; anything else
            as ``norm_layer(planes)``. Defaults to ``nn.BatchNorm2d``.
        module_key: prefix prepended to every key looked up in ``weights``.
        planes: number of channels produced by the block's convolutions;
            sizes the normalization layer.
    """
    expansion = 1

    def __init__(self, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, module_key='',
                 planes=64):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        if isinstance(norm_layer, type) and issubclass(norm_layer, nn.GroupNorm):
            # GroupNorm's signature is (num_groups, num_channels); 32 groups
            # matches the F.group_norm(..., 32, ...) call used for the stem.
            self.bn = norm_layer(32, planes)
        else:
            self.bn = norm_layer(planes)
        # NOTE(review): this single norm module is applied after conv1, conv2
        # and the shortcut projection; with BatchNorm2d their running
        # statistics are mixed together. Confirm the sharing is intended.
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.module_key = module_key

    def forward(self, inputs):
        """Apply the block to ``inputs = (x, weights)`` and return the output.

        Every kernel is looked up in ``weights`` and moved to ``DEVICE``
        before use. ``weights`` is also stored as ``self.weights``.
        """
        module_key = self.module_key
        x = inputs[0]
        self.weights = inputs[1]
        identity = x

        out = FConv2d(x, self.weights[module_key + 'conv1.weight'].to(DEVICE),
                      stride=self.stride)
        out = self.relu(self.bn(out))
        out = FConv2d(out, self.weights[module_key + 'conv2.weight'].to(DEVICE))
        out = self.bn(out)

        if self.downsample is not None:
            # padding=0: the shortcut kernel is presumably 1x1 so that its
            # output matches ``out`` spatially -- confirm against the weights.
            identity = FConv2d(
                x, self.weights[module_key + 'downsample.0.weight'].to(DEVICE),
                stride=self.stride, padding=0)
            identity = self.bn(identity)

        out += identity
        return self.relu(out)

# Disabled Bottleneck block, kept only as a string literal: it calls
# conv1x1/conv3x3 (not defined in this module) and owns its conv modules
# instead of reading kernels from a weights dict like BasicBlock does.
"""
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
"""

class ResNet(nn.Module):
    """ResNet whose convolution and classifier weights are passed at call time.

    Stage ``k`` (1..4) is split into ``layerk_1`` (block 0, which may use a
    projection shortcut) and ``layerk_2`` (blocks 1..n-1). Blocks read their
    kernels from the ``weights`` dict under the prefix
    ``'module.layer<k>.<i>.'``. With ``BasicBlock`` the only registered
    parameters are the blocks' normalization layers.

    Args:
        block: residual block class (e.g. ``BasicBlock``).
        layers: four ints, the number of blocks in each stage.
        num_classes: stored as ``self.n_classes``; the classifier output size
            is actually determined by ``weights['module.fc.weight']``.
        zero_init_residual: accepted for compatibility; unused.
        groups, width_per_group: forwarded to every block.
        norm_layer: normalization constructor forwarded to every block;
            defaults to ``nn.BatchNorm2d``.
    """

    def __init__(self, block, layers, num_classes=10, zero_init_residual=False,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        self.n_classes = num_classes
        self.groups = groups
        self.base_width = width_per_group

        # Stage counter used by _make_layer1/_make_layer2 to build weight keys;
        # each _make_layer1 call advances it, _make_layer2 reuses it.
        self.temp_block_key = 0
        self.layer1_1 = self._make_layer1(block, 64, layers[0], stride=1)
        self.layer1_2 = self._make_layer2(block, 64, layers[0], stride=1)
        self.layer2_1 = self._make_layer1(block, 128, layers[1], stride=2)
        self.layer2_2 = self._make_layer2(block, 128, layers[1], stride=2)
        self.layer3_1 = self._make_layer1(block, 256, layers[2], stride=2)
        self.layer3_2 = self._make_layer2(block, 256, layers[2], stride=2)
        self.layer4_1 = self._make_layer1(block, 512, layers[3], stride=2)
        self.layer4_2 = self._make_layer2(block, 512, layers[3], stride=2)

        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def _make_layer1(self, block, planes, blocks, stride=1):
        """Build block 0 of the next stage, with a projection shortcut if needed.

        ``blocks`` is accepted for symmetry with ``_make_layer2`` and unused.
        Advances ``self.temp_block_key`` and updates ``self.inplanes``.
        """
        norm_layer = self._norm_layer
        self.temp_block_key += 1
        previous_dilation = self.dilation
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = True
        first = block(stride, downsample, self.groups,
                      self.base_width, previous_dilation, norm_layer,
                      module_key='module.layer%s.0.' % self.temp_block_key,
                      planes=planes)
        self.inplanes = planes * block.expansion
        return nn.Sequential(first)

    def _make_layer2(self, block, planes, blocks, stride=1):
        """Build blocks 1..blocks-1 of the current stage (stride 1, no projection).

        ``stride`` is accepted for symmetry with ``_make_layer1`` and ignored.
        Must be called right after the matching ``_make_layer1`` because the
        weight-key prefix uses the current ``self.temp_block_key``.
        """
        norm_layer = self._norm_layer
        layers = [
            block(groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=norm_layer,
                  # Each block gets its own index so blocks beyond the second
                  # do not all read 'module.layer<k>.1.' weights.
                  module_key='module.layer%s.%s.' % (self.temp_block_key, i),
                  planes=planes)
            for i in range(1, blocks)
        ]
        return nn.Sequential(*layers)

    def l1_pruning(self, input, l1_prune):
        """Hard-threshold ``input`` by magnitude unless ``l1_prune`` is truthy.

        Returns ``input`` unchanged when ``l1_prune`` is truthy; otherwise
        zeroes every entry with ``|input| <= self.args.l1_hyp``.

        NOTE(review): the branch reads inverted relative to the flag's name,
        and ``self.args`` is never set by this class -- confirm both with the
        caller. Nothing in this class calls this method.
        """
        if l1_prune:
            return input
        else:
            hard_threshold = torch.abs(input) > self.args.l1_hyp
            return input * hard_threshold

    def embed(self, x, weights):
        """Return pooled, flattened features of shape ``(batch, channels)``.

        Args:
            x: input image batch.
            weights: dict of weight tensors (see the class docstring).
        """
        keylists = list(weights.keys())
        # NOTE(review): the stem is looked up by position, not by name:
        # keylists[0] is used as the conv kernel and keylists[1]/[2] as the
        # group-norm weight/bias, so this depends on the dict's insertion order.
        x = FConv2d(x, weights[keylists[0]].to(DEVICE))
        x = F.group_norm(x, 32, weights[keylists[1]].to(DEVICE),
                         weights[keylists[2]].to(DEVICE))
        x = self.relu(x)

        stages = (self.layer1_1, self.layer1_2, self.layer2_1, self.layer2_2,
                  self.layer3_1, self.layer3_2, self.layer4_1, self.layer4_2)
        for stage in stages:
            # Call each block directly with (x, weights): nn.Sequential would
            # hand the second block a bare tensor, and an empty stage would
            # return the input tuple unchanged.
            for blk in stage:
                x = blk((x, weights))

        x = self.avgpool(x)
        return x.reshape(x.size(0), -1)

    def forward(self, x, weights):
        """Return class logits: ``embed`` features through the ``module.fc`` layer."""
        feats = self.embed(x, weights)
        return F.linear(feats, weights['module.fc.weight'].to(DEVICE),
                        bias=weights['module.fc.bias'].to(DEVICE))

    def get_params(self):
        """Return all registered parameters flattened into one 1-D tensor."""
        return torch.cat([pp.view(-1) for pp in self.parameters()])

    def get_grads(self):
        """Return all parameter gradients flattened into one 1-D tensor.

        Parameters without a gradient contribute zeros, so the result lines
        up element-for-element with ``get_params()``.
        """
        grads = []
        for pp in self.parameters():
            if pp.grad is None:
                grads.append(torch.zeros_like(pp).view(-1))
            else:
                grads.append(pp.grad.view(-1))
        return torch.cat(grads)


def _resnet(arch, block, layers, device, **kwargs):
    """Construct a ``ResNet`` and move it to ``device``.

    Args:
        arch: architecture name; informational only, not used.
        block: residual block class passed to ``ResNet``.
        layers: number of blocks per stage, passed to ``ResNet``.
        device: device the model's registered parameters and buffers are
            moved to.
        **kwargs: forwarded to ``ResNet``.

    Returns:
        The constructed model (already on ``device``).
    """
    model = ResNet(block, layers, **kwargs)
    # Only registered modules move here; tensors in the ``weights`` dict are
    # moved to DEVICE on every forward call instead.
    return model.to(device)

def resnet18(device='cpu', **kwargs):
    """Build a ResNet-18: four stages of two ``BasicBlock``s each.

    ``device`` and any extra keyword arguments are handed to ``_resnet``.
    """
    stage_depths = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, layers=stage_depths, device=device, **kwargs)

def resnet50(device='cpu', **kwargs):
    """Placeholder for ResNet-50; not available in this module.

    ResNet-50 needs a Bottleneck block that follows the ``(x, weights)``
    calling convention, and this module defines none, so the previous body
    could only fail with ``NameError``.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "resnet50 requires a Bottleneck block that follows the (x, weights) "
        "calling convention; no such block is implemented in this module.")
