import torch
import torch.nn as nn
import os
import pdb
from torch.nn import functional as F
import copy
from utils.conf import get_device

def APDConv2d(self, x, key_weights, task_id, stride=1, padding=1, dilation=1, dims=(), module_key='', info='weights', tad_thr=0, device='cuda'):
    """Run a bias-free 2-D convolution whose kernel is assembled from APD parts.

    The kernel is ``msk * tsh + l1_pruning(tad, tad_thr)``: ``tsh`` is the shared
    kernel, ``msk`` scales each output channel (first kernel dimension) and
    ``tad`` is a per-task additive kernel.

    Args:
        self: module on which any lazily created parameters are registered.
        x: input tensor passed to ``F.conv2d``.
        key_weights: dict with ``'tsh'``, ``'msk'`` and ``'tad'`` containers.
        task_id: either an ``int`` selecting already-existing per-task weights,
            or a ``str`` holding a task count ``n``. In the string case the
            weights of tasks ``0..n-1`` are created if missing, and those of
            task ``n-1`` are used for the convolution.
        stride, padding, dilation: forwarded to ``F.conv2d``.
        dims: kernel shape ``[out, in, kh, kw]``. It is only read when
            ``task_id`` is a ``str``, and ``dims[0]`` sizes the mask.
        module_key, info: joined as ``module_key + '__' + info`` to form the
            weight keys.
        tad_thr: magnitude threshold applied to ``tad`` via ``l1_pruning``.
        device: device the weight parts are moved to before use.

    Returns:
        The output of ``F.conv2d``.

    Raises:
        ValueError: if ``task_id`` is a string that is not a positive integer.
    """
    key = module_key + '__' + info
    if isinstance(task_id, str):
        n_tasks = int(task_id)
        if n_tasks < 1:
            # Without at least one task there is no mask/adaptive weight to use.
            raise ValueError('task_id given as a string must be a positive task count, got %r' % (task_id,))
        tsh = get_key_weights(self, key_weights, key, 'tsh.shared', dims).to(device)
        # Materialise mask/adaptive weights for every task seen so far; the
        # convolution below uses the parts of the last one (n_tasks - 1).
        for tid in range(n_tasks):
            msk = get_key_weights(self, key_weights, key, 'msk.t%d' % tid, [dims[0]]).to(device)
            tad = get_key_weights(self, key_weights, key, 'tad.t%d' % tid, dims, is_tad=True).to(device)
    else:
        tsh = key_weights['tsh']['%s_shared' % key].to(device)
        msk = key_weights['msk']['%s_t%d' % (key, task_id)].to(device)
        tad = key_weights['tad']['%s_t%d' % (key, task_id)].to(device)
    weight = torch.einsum('i,i...->i...', msk, tsh) + l1_pruning(tad, tad_thr)
    return F.conv2d(x, weight, stride=stride, padding=padding, dilation=dilation)

def l1_pruning(input, l1_hyp=0):
    """Hard-threshold ``input`` by magnitude.

    A threshold equal to 0 disables pruning and returns ``input`` itself.
    Otherwise a new tensor is returned in which every element with
    ``|value| <= l1_hyp`` is multiplied by zero and the rest are kept.
    """
    if l1_hyp == 0:
        return input
    survivors = input.abs() > l1_hyp
    return input * survivors

def get_key_weights(self, key_weights, w_info, w_type, shape, is_tad=False):
    """Return the APD weight ``w_type`` belonging to ``w_info``, creating it on first use.

    Args:
        self: module on which a newly created parameter is registered under
            the name ``'<type>/<key>'``.
        key_weights: dict mapping ``'tsh'``/``'msk'``/``'tad'`` to parameter
            containers; updated in place when a weight is created.
        w_info: weight identifier, e.g. ``'layer1_0__conv1_weights'``.
        w_type: ``'<type>.<suffix>'`` such as ``'tsh.shared'`` or ``'msk.t0'``;
            the weight is stored under ``'<w_info>_<suffix>'``.
        shape: shape of a newly created non-adaptive weight.
        is_tad: create the weight as a deep copy of the matching shared
            ``tsh`` weight, halved, instead of from ``shape``. The shared
            weight must already exist in that case.

    Returns:
        The stored parameter.
    """
    weights_type, weights_tid = w_type.split('.')
    key = '%s_%s' % (w_info, weights_tid)
    store = key_weights[weights_type]
    if key in store:
        return store[key]

    param_name = '%s/%s' % (weights_type, key)
    if is_tad:
        # Task-adaptive weights start as half of the shared kernel.
        shared = key_weights['tsh']['%s_%s' % (w_info, 'shared')]
        w = copy.deepcopy(shared)
        w.data /= 2
        store[key] = w
        self.register_parameter(name=param_name, param=w)
    else:
        w = nn.Parameter(torch.ones(shape))
        self.register_parameter(name=param_name, param=w)
        # Rank <= 2 weights (e.g. masks) keep the all-ones init; higher-rank
        # kernels get Kaiming-normal init.
        if len(shape) > 2:
            nn.init.kaiming_normal_(w.data, mode='fan_out', nonlinearity='relu')
        store[key] = w
    return store[key]


class BasicBlock(nn.Module):
    """Two-conv residual block whose convolutions are ``APDConv2d`` calls.

    Conv weights are looked up (or lazily created) in the shared
    ``key_weights`` dict under keys derived from ``module_key``; only the
    normalisation layers are ordinary submodules.

    Args:
        inplanes, planes: input / output channel counts.
        stride: stride of the first conv and of the projection shortcut.
        downsample: truthy to use a 1x1 projection shortcut (plus ``bn3``).
        groups, base_width, dilation: only ``1``, ``64`` and ``<= 1`` are
            supported.
        norm_layer: normalisation layer factory (default ``nn.BatchNorm2d``).
        key_weights: shared APD weight store passed to ``APDConv2d``.
        tad_thr: threshold forwarded to ``APDConv2d``.
        module_key: prefix of this block's weight keys.
        device: device the APD weights are moved to; ``None`` (default) uses
            the device of the input tensor.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 key_weights=None, tad_thr=0, module_key='', device=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both conv1 and the downsample conv downsample the input when stride != 1
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        self.tad_thr = tad_thr
        self.module_key = module_key
        self.key_weights = key_weights
        # None -> follow the input's device, so the block also runs on CPU-only
        # machines instead of always forcing weights onto 'cuda'.
        self.device = device
        self.inplanes = inplanes
        self.planes = planes
        self.dilation = dilation

        if self.downsample:
            # Normalisation after the projection shortcut; built with the same
            # factory as the main branch for consistency.
            self.bn3 = norm_layer(planes * self.expansion)

    def forward(self, inputs):
        """Apply the block to ``(x, task_id)`` and return ``(out, task_id)``.

        The tuple in/out signature lets blocks be chained in ``nn.Sequential``
        while threading ``task_id`` through.
        """
        x, task_id = inputs
        device = self.device if self.device is not None else x.device
        identity = x

        out = APDConv2d(self, x, self.key_weights, task_id, stride=self.stride, dilation=self.dilation,
                        dims=[self.planes, self.inplanes, 3, 3],
                        module_key=self.module_key, info='conv1_weights',
                        tad_thr=self.tad_thr, device=device)
        out = self.relu(self.bn1(out))

        out = APDConv2d(self, out, self.key_weights, task_id, dilation=self.dilation,
                        dims=[self.planes, self.planes, 3, 3],
                        module_key=self.module_key, info='conv2_weights',
                        tad_thr=self.tad_thr, device=device)
        out = self.bn2(out)

        if self.downsample:
            # 1x1 projection shortcut (padding 0) so identity matches `out`.
            identity = APDConv2d(self, x, self.key_weights, task_id, stride=self.stride, padding=0,
                                 dilation=self.dilation, dims=[self.planes, self.inplanes * 1, 1, 1],
                                 module_key=self.module_key, info='downsample_weights',
                                 tad_thr=self.tad_thr, device=device)
            identity = self.bn3(identity)
        out += identity
        out = self.relu(out)
        return out, task_id


class ResNet(nn.Module):
    """ResNet backbone whose convolutions are built with ``APDConv2d``.

    The decomposed conv weights live in ``self.key_weights``, a dict shared by
    reference with every residual block. ``APDConv2d`` creates them lazily when
    ``forward`` receives a string ``task_id``. The normalisation layers and the
    final ``fc`` are ordinary submodules.

    Args:
        block: residual block class (e.g. ``BasicBlock``).
        layers: number of blocks in each of the four stages.
        num_classes: output size of ``fc``.
        tad_thr: threshold forwarded to every ``APDConv2d`` call.
        zero_init_residual: zero the last norm weight of each residual branch.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: three flags, one per stage 2-4.
        norm_layer: normalisation layer factory (default ``nn.BatchNorm2d``).
    """

    def __init__(self, block, layers, num_classes=100, tad_thr=0, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        self.n_classes = num_classes
        self.device = get_device()
        # Shared APD weight store: 'tsh' holds '*_shared' kernels, 'msk' and
        # 'tad' hold per-task ('*_t<id>') masks / adaptive kernels.
        self.key_weights = {'tsh': nn.ParameterDict(),
                            'msk': nn.ParameterDict(),
                            'tad': nn.ParameterDict(),
                            'tmp_sol': {},  # Frozen
                            }

        self.tad_thr = tad_thr

        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # CIFAR-style stem: the 3x3 / stride-1 stem conv is an APDConv2d call in
        # `embed`, and there is no max-pool.
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0], module_key='layer1')
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0], module_key='layer2')
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1], module_key='layer3')
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2], module_key='layer4')
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, self.n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            # `Bottleneck` is looked up at call time: a bare reference raises
            # NameError when this module does not define it.
            bottleneck_cls = globals().get('Bottleneck')
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
                elif bottleneck_cls is not None and isinstance(m, bottleneck_cls):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, module_key=''):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        Blocks are keyed ``'<module_key>_<index>'`` in the shared weight store.
        Only the first block may downsample (stride / channel change).
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = True
        layers = []
        _module_key = module_key + '_0'
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer,
                            key_weights=self.key_weights, tad_thr=self.tad_thr, module_key=_module_key))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            _module_key = module_key + '_%s' % i
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer,
                                key_weights=self.key_weights, tad_thr=self.tad_thr, module_key=_module_key))
        return nn.Sequential(*layers)

    def embed(self, x, task_id):
        """Return the pooled, flattened features and the list of stage outputs."""
        # The stem width uses base_width, which BasicBlock requires to be 64
        # (the same width bn1 was built with).
        x1 = APDConv2d(self, x, self.key_weights, task_id, module_key='conv1', info='weights',
                       dims=[self.base_width, 3, 3, 3], tad_thr=self.tad_thr, device=self.device)
        _x1 = F.relu(self.bn1(x1))
        x2, _ = self.layer1((_x1, task_id))
        x3, _ = self.layer2((x2, task_id))
        x4, _ = self.layer3((x3, task_id))
        x5, _ = self.layer4((x4, task_id))
        x = self.avgpool(x5)
        x = x.reshape(x.size(0), -1)
        return x, [x1, x2, x3, x4, x5]

    def forward(self, x, task_id, return_features=False, plot_metrics=False):
        """Classify ``x`` for ``task_id``.

        Returns the pooled features if ``return_features``; otherwise the
        logits, paired with the stage outputs when ``plot_metrics`` is set.
        """
        x, features_list = self.embed(x, task_id)
        if return_features:
            return x
        x = self.fc(x)

        if plot_metrics:
            return x, features_list
        return x

    def get_params(self):
        """Concatenate all parameters into one flat tensor."""
        return torch.cat([pp.view(-1) for pp in self.parameters()])

    def get_grads(self):
        """Concatenate all parameter gradients into one flat tensor.

        Parameters whose ``.grad`` is None (e.g. not reached by the last
        backward pass) contribute zeros, so the result stays aligned
        element-by-element with ``get_params``.
        """
        grads = [pp.grad.reshape(-1) if pp.grad is not None else torch.zeros_like(pp).view(-1)
                 for pp in self.parameters()]
        return torch.cat(grads)


def ResNet18(**kwargs):
    """Build an 18-layer APD ResNet from ``BasicBlock`` (2-2-2-2 stages)."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, **kwargs)

def ResNet34(**kwargs):
    """Build a 34-layer APD ResNet from ``BasicBlock`` (3-4-6-3 stages)."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, **kwargs)

def ResNet50(**kwargs):
    """Build a 50-layer APD ResNet from ``Bottleneck`` blocks (3-4-6-3 stages).

    NOTE(review): ``Bottleneck`` is not defined in the visible part of this
    module. Confirm it is provided elsewhere; otherwise calling this raises
    NameError.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)

def ResNet101(**kwargs):
    """Build a 101-layer APD ResNet from ``Bottleneck`` blocks (3-4-23-3 stages).

    NOTE(review): ``Bottleneck`` is not defined in the visible part of this
    module. Confirm it is provided elsewhere; otherwise calling this raises
    NameError.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)

def ResNet152(**kwargs):
    """Build a 152-layer APD ResNet from ``Bottleneck`` blocks (3-8-36-3 stages).

    NOTE(review): ``Bottleneck`` is not defined in the visible part of this
    module. Confirm it is provided elsewhere; otherwise calling this raises
    NameError.
    """
    stage_depths = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)


def test():
    """Smoke-test ResNet18 on one CIFAR-sized input and print the output size.

    ``forward`` requires a ``task_id``. Passing the string ``'1'`` makes
    ``APDConv2d`` create the weights for task 0 on first use.
    """
    net = ResNet18()
    # Keep the norm layers / fc on the same device as the APD weights and input.
    net.to(net.device)
    y = net(torch.randn(1, 3, 32, 32, device=net.device), '1')
    print(y.size())
