'''
Modified from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
'''
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from utils.common import interpolateFeatures,_get_num_features
import torch.nn.functional as F
from utils.features import CrossStitch


# Public API for ``from <module> import *``: the VGG class and all vgg* factory functions.
__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19', 'vgg9_bn',
    # Compact (cfg_small) factories, exported alongside vgg9_bn for consistency.
    'vgg4', 'vgg4_bn', 'vgg9',
]

# Module-level default device: CUDA when available, otherwise CPU.
# NOTE(review): modules should prefer the device of their input tensors over this global.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## For models pre-trained on ImageNet
# Checkpoint URLs keyed by the factory name that loads them (vgg11 ... vgg19_bn below).
model_urls = {
   'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
   'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
   'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
   'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
   'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
   'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
   'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
   'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}


class VGG(nn.Module):
    """VGG network assembled from ``ConvolutionLayer`` stages built by ``make_layers``.

    When both ``transfer_types`` and ``source_info`` are supplied, every stage is built
    with transfer modules so that ``forward`` can fuse feature maps coming from a
    source network; otherwise the network runs as a plain VGG.

    Args:
        cfg: Layer specification. Integers are 3x3 conv widths; ``'M'`` closes a stage.
        name: Model name; also used to look up ``self.channels`` via ``_get_num_features``.
        batch_norm: Insert BatchNorm after every convolution.
        num_classes: Output size of the classifier head.
        init_weights: Re-initialise all conv/BN/linear parameters after construction.
        transfer_types: Transfer mechanism(s) forwarded to every stage.
        source_info: Tuple ``(source_feature_ids, source_input_pass_id, decisioners,
            src_channels)`` describing the source network; ``src_channels`` is indexed
            by source feature id and ``decisioners`` by stage position.
        ru_units: Forwarded to every stage.
        compact: Use a single ``Linear(512, num_classes)`` head instead of the
            three-layer MLP head.
    """
    source_feature_ids = [0, 1, 2, 3, 4]  # one for each internal representation
    target_route_ids = [0, 1, 2, 3, 4]  # one for each ConvolutionLayer

    def __init__(self, cfg, name='vgg', batch_norm=False, num_classes=10, init_weights=True, transfer_types=None,
                 source_info=None, ru_units=True, compact=False):
        super(VGG, self).__init__()
        self.name = name
        self.num_classes = num_classes
        self.channels = _get_num_features(name)

        if transfer_types and source_info:
            self.source_feature_ids, self.source_input_pass_id, decisioners, src_channels = source_info
            # Channel count of each source feature the stages will receive.
            self.source_input_channels = [src_channels[ids] for ids in self.source_feature_ids]
            source_input_info = (self.source_feature_ids, self.source_input_channels)
            self.transfer_types = transfer_types
            self.ru_units = ru_units
            self.layers = make_layers(cfg, batch_norm, decisioners, transfer_types, source_input_info, ru_units)
        else:
            self.layers = make_layers(cfg, batch_norm)

        self.nlayers = len(self.layers)
        if compact:
            self.fc = nn.Linear(512, num_classes)
        else:
            self.fc = nn.Sequential(
                nn.Linear(512, 512),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(512, 512),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(512, num_classes),
            )
        if init_weights:
            self._initialize_weights()

    def forward(self, x, ref_features=None, pairs=None, weights=None):
        """Run the network, optionally fusing source-network features stage by stage.

        Args:
            x: Input batch. After the last stage the spatial dimensions are globally
                average-pooled, so the last stage must output 512 channels.
            ref_features: Source-network feature maps indexed by source feature id, or
                ``None`` to run the plain network.
            pairs: Required when ``ref_features`` is given. One
                ``(target_layer_id, source_feature_id, transfer_type)`` tuple per stage;
                the stage is chosen by the tuple's position (``target_layer_id`` is not
                read). ``source_feature_id == -1`` hands all configured source features
                to the stage; ``source_input_pass_id`` means no source feature is used
                (a zero tensor is supplied for 'combine'/'shared'/'block').
            weights: Unused; kept for interface compatibility.

        Returns:
            ``(logits, feats)`` where ``feats`` lists each stage's output in order.
        """
        feats = []
        if ref_features is None:
            for layer in self.layers:
                x = layer(x)
                feats.append(x)
        else:  # Auto-transfer
            assert pairs is not None
            for l, (target_layer_id, source_feature_id, transfer_type) in enumerate(pairs):
                ref_input = None
                if source_feature_id != self.source_input_pass_id:
                    if source_feature_id == -1:
                        # Give the stage every configured source feature; it takes an
                        # expectation over them.
                        ref_input = [ref_features[sid] for sid in self.source_feature_ids]
                    else:
                        ref_input = ref_features[source_feature_id]
                else:
                    # PASS: no source feature for this stage.
                    source_feature_id = self.source_feature_ids[l]
                    if transfer_type in ['combine', 'shared', 'block']:
                        # Allocate on the input's own device rather than a module-global
                        # one, so CPU inference on a CUDA host and non-default GPUs work.
                        ref_input = torch.zeros(
                            [x.size(0), self.source_input_channels[l], x.size(2), x.size(3)],
                            dtype=x.dtype, device=x.device)
                x = self.layers[l](x, (source_feature_id, transfer_type, ref_input))
                feats.append(x)

        # Global average pooling over all spatial positions (also valid for non-square maps).
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x, feats

    def _initialize_weights(self):
        """Kaiming-normal conv weights, unit/zero BatchNorm, N(0, 0.01) linear weights, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


# The customized model from https://arxiv.org/abs/1803.00443
class VGG_small(nn.Module):
    """Compact VGG variant wrapped around an externally built ``features`` stack.

    Args:
        features: Iterable of layers (e.g. ``nn.Sequential``). ``forward`` records the
            activation entering every ``nn.MaxPool2d``; the final activation must have
            512 channels to match the classifier.
        num_classes: An int for a single head, or a list of ints for one linear head per
            task (selected with ``idx`` in ``forward``).
        init_weights: Re-initialise conv/BN/linear parameters after construction.
        transfer_types, source_info, ru_units: Accepted for signature compatibility;
            not used by this class.
        lwf: Add ``lwf_lyr`` (``Linear(512, num_source_cls)``); its output is returned
            as a third value from ``forward``.
        num_source_cls: Output size of ``lwf_lyr``.
        no_ft: Create ``outputs_branch`` heads on 64-, 128- and 256-channel features.

    ``alphas``, ``new_classifier``, ``new_bn``, ``w1``-``w3``/``w`` and ``outputs_branch``
    are created here but not used by any method of this class; presumably they are
    consumed by external training code (TODO confirm).
    """

    def __init__(self, features, num_classes=10, init_weights=True, transfer_types=None, source_info=None,
                 ru_units=True, lwf=False, num_source_cls=200, no_ft=False):
        super(VGG_small, self).__init__()
        self.features = features

        self.num_classes = num_classes
        # nn.Linear cannot take a list of sizes; a list of class counts becomes one head per
        # task. Every head below uses the same rule so the list case is constructible.
        self.classifier = self._make_head(512, num_classes)

        self.lwf = lwf
        if self.lwf:
            self.lwf_lyr = nn.Linear(512, num_source_cls)

        self.no_ft = no_ft
        if self.no_ft:
            self.outputs_branch = nn.ModuleList(
                [self._make_head(64, num_classes),
                 self._make_head(128, num_classes),
                 self._make_head(256, num_classes)])

        self.alphas = nn.ParameterList([nn.Parameter(torch.rand(3, 1, 1, 1) * 0.1),
                                        nn.Parameter(torch.rand(3, 1, 1, 1) * 0.1),
                                        nn.Parameter(torch.rand(3, 1, 1, 1) * 0.1)])

        self.new_classifier = self._make_head(512, num_classes)
        # One fresh BatchNorm for each BatchNorm2d found in ``features``, in order.
        self.new_bn = nn.ModuleList()
        for layer in self.features:
            if isinstance(layer, nn.BatchNorm2d):
                self.new_bn.append(nn.BatchNorm2d(layer.num_features))

        self.w1 = nn.Linear(64, 16)
        self.w2 = nn.Linear(128, 32)
        self.w3 = nn.Linear(256, 64)
        self.w = nn.ModuleList([self.w1, self.w2, self.w3])
        if init_weights:
            self._initialize_weights()

    @staticmethod
    def _make_head(in_features, num_classes):
        """Return ``Linear(in_features, num_classes)``, or a ``ModuleList`` of per-task heads for a list."""
        if isinstance(num_classes, list):
            return nn.ModuleList([nn.Linear(in_features, n) for n in num_classes])
        return nn.Linear(in_features, num_classes)

    def forward(self, x, idx=-1):
        """Run ``features``, global-average-pool, and classify.

        Args:
            x: Input batch.
            idx: Which head to use when ``num_classes`` is a list (default: the last).

        Returns:
            ``(logits, feat)``, or ``(logits, feat, old_out)`` when ``lwf`` is set, where
            ``feat`` holds the activation entering each ``nn.MaxPool2d``.
        """
        feat = []
        for layer in self.features:
            if isinstance(layer, nn.MaxPool2d):
                feat.append(x)
            x = layer(x)
        # Global average pooling over all spatial positions (also valid for non-square maps).
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)

        if self.lwf:
            old_out = self.lwf_lyr(x)

        if isinstance(self.num_classes, list):
            x = self.classifier[idx](x)
        else:
            x = self.classifier(x)

        if self.lwf:
            return x, feat, old_out
        return x, feat

    def forward_with_features(self, x):
        """Alias for ``forward`` with the default head."""
        return self.forward(x)

    def forward_with_combine_features(self, x, fs, metanet):
        """Delegate to ``combine_forward(x, fs, metanet, 0)``.

        This class does not define ``combine_forward``; a subclass must provide it.

        Raises:
            NotImplementedError: if no ``combine_forward`` method is available.
        """
        combine_forward = getattr(self, 'combine_forward', None)
        if combine_forward is None:
            raise NotImplementedError(
                type(self).__name__ + '.forward_with_combine_features requires a combine_forward method')
        return combine_forward(x, fs, metanet, 0)

    def _initialize_weights(self):
        """Kaiming-normal conv weights, unit/zero BatchNorm, N(0, 0.01) linear weights, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding on each side.

    With ``stride=1`` the spatial size of the input is preserved.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class ConvolutionLayer(nn.Module): # Shareable Convolutional Layer
    """One VGG stage with optional source-feature fusion.

    The stage applies ``layers`` 3x3 convolutions (each optionally followed by
    BatchNorm). All but the last are followed by ReLU inside ``self.blocks``. The last
    convolution's output may be fused with a source feature via a cross-stitch unit,
    and then goes through ReLU and a 2x2 max-pool.
    """

    def __init__(self, id, in_channels, out_channels, layers=1, batch_norm=False, decisioner=None, transfer_types=None,
                 source_input_info=None, ru_units=True):
        """Build one stage.

        Args:
            id: Stage number (``make_layers`` passes 1-based numbers). Shadows the
                builtin ``id`` inside this method.
            in_channels: Channels entering the stage.
            out_channels: Channels produced by every convolution of the stage.
            layers: Number of 3x3 convolutions. ``layers - 1`` of them live in
                ``self.blocks``; the last one is ``self.conv1``.
            batch_norm: Add BatchNorm after each convolution.
            decisioner: Object exposing a NumPy ``probabilities`` array, used by
                ``forward`` when taking an expectation over source features.
            transfer_types: Iterable of transfer types; only ``'xstitch'`` is accepted.
            source_input_info: ``(source_feature_ids, source_input_channels)``; must be
                given when ``transfer_types`` is set.
            ru_units: Create ``self.ru``, one 1x1 conv + BatchNorm per source feature,
                projecting that source's channels to ``out_channels``.

        Raises:
            Exception: if ``transfer_types`` contains anything other than ``'xstitch'``.
        """
        super(ConvolutionLayer, self).__init__()
        self.id = id
        self.transfer_types = transfer_types
        self.decisioner = decisioner
        self.route_wts = []  # NOTE(review): not read anywhere in this class.
        self.ru_units = ru_units

        cblocks = []
        self.batch_norm = batch_norm
        _in_channels, _out_channels = in_channels, out_channels
        # All but the last convolution: conv -> [BN] -> ReLU.
        for l in range(layers-1):
            cblocks.append(conv3x3(_in_channels, _out_channels))
            if self.batch_norm:
                cblocks.append(nn.BatchNorm2d(_out_channels))
            cblocks.append(nn.ReLU(inplace=True))
            _in_channels = _out_channels
        self.blocks = nn.Sequential(*cblocks)

        # Last convolution is kept separate so source features can be fused before its ReLU.
        self.conv1 = conv3x3(_in_channels, _out_channels)
        if self.batch_norm:
            self.bn1 = nn.BatchNorm2d(_out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # self.ru only exists when ru_units is true.
        if self.ru_units:
            self.ru = nn.ModuleList([])

        self.cs = nn.ModuleDict({})
        self.id2mod = None
        if transfer_types:
            print('Layer ' + str(id) + ' using ' + str(transfer_types) + '.')
            assert source_input_info is not None
            source_feature_ids, source_input_channels = source_input_info
            self.id2mod = {ids: i for i, ids in enumerate(source_feature_ids)}  # maps arm d to the module
            self.id2mod[-1] = -1
            self.source_input_channels = source_input_channels  # [64, 64, 128, 256, 512]
            # self.src_chan_vec = [200704, 200704, 50176, 12544]  # This has to be computed for each dataset

            if self.ru_units:
                self.ru.extend([nn.Sequential(
                    nn.Conv2d(src_channels, out_channels, 1),
                    nn.BatchNorm2d(out_channels),
                ) for src_channels in self.source_input_channels])
            for transfer_type in set(self.transfer_types):
                if transfer_type == 'xstitch':
                    # One cross-stitch unit per source feature.
                    # NOTE(review): built with in_channels, but forward applies it to
                    # conv1's output and to ru projections, both of which have
                    # out_channels; these differ whenever in_channels != out_channels.
                    # CrossStitch is defined elsewhere — confirm which width it needs.
                    self.cs['xstitch'] = nn.ModuleList(
                        [CrossStitch(id, i, in_channels) for i in range(len(self.source_input_channels))])
                else:
                    raise Exception("No valid transfer type found. Supported type: Cross Stitch")

    def forward(self, x, source_x = None):
        """Run the stage, optionally fusing a source feature before the final ReLU.

        Args:
            x: Input feature map. ``size(3)`` is read below, so it is at least 4-D
                (presumably ``(N, C, H, W)`` — TODO confirm).
            source_x: Optional ``(src_id, transfer_type, ref_input)`` tuple. ``src_id``
                is translated through ``id2mod`` (``-1`` means expectation over all
                sources). ``ref_input`` is ``None`` (no fusion), one tensor, or a list
                of tensors (one per configured source feature).

        Returns:
            The stage output after ReLU and 2x2 max-pooling.

        Raises:
            Exception: if a source feature is supplied with a transfer type other than
                ``'xstitch'``.
        """
        ref_id, transfer_type, ref_input = -1, '', None
        if source_x is not None:
            src_id, transfer_type, ref_input = source_x
            # Translate the source feature id into an index into self.ru / self.cs[...].
            ref_id = self.id2mod[src_id] if self.id2mod else -1

        out =  self.blocks(x)
        out = self.conv1(out)
        if self.batch_norm:
            out = self.bn1(out)
        # Add source features here with ru units
        input = out
        if ref_input is not None:
            if self.ru_units:
                if type(ref_input) is not list:
                    ref_input = interpolateFeatures(ref_input, input.size(3)) # Matches the last two dimensions
                    ref_input = self.ru[ref_id](ref_input)
                else:
                    # Resize and project every source, then add an all-zero map as the
                    # PASS candidate.
                    ref_input = [interpolateFeatures(rin, input.size(3)) for rin in ref_input]
                    ref_input = [rproj(rin) for (rin, rproj) in zip(ref_input, self.ru)]
                    ref_input.append(torch.zeros_like(input)) # PASS option or sentinel attention
            if ref_id == -1 and transfer_type != 'attention':
                # take expectation over the ref_inputs
                # Candidates are stacked on a new last axis and contracted with the
                # decisioner's probabilities. NOTE(review): the candidate count (sources,
                # plus the PASS map when ru_units) must equal len(probabilities) — confirm.
                ref_input = torch.stack(ref_input,dim=-1)
                p = torch.from_numpy(self.decisioner.probabilities).float().to(input.device)
                ref_input = torch.matmul(ref_input,p)
                # NOTE(review): the fused map is routed to cross-stitch unit self.id - 1,
                # i.e. this stage's 0-based position (make_layers numbers stages from 1),
                # not a source index — confirm intended.
                ref_id = self.id-1
            if transfer_type in ['xstitch']:
                # output=self.cs[transfer_type][ref_id](input, self.ru[ref_id](ref_input))
                out = self.cs[transfer_type][ref_id](input, ref_input)
            else:
                raise Exception("No valid transfer type found. Supported type: Cross Stitch")
        out = self.relu1(out)
        out = self.pool1(out)
        return out


def make_layers(cfg, batch_norm=False, decisioner=None, transfer_types=None, source_input_info=None, ru_units=True):
    """Turn a VGG layer specification into an ``nn.Sequential`` of ConvolutionLayer stages.

    Each run of integers in ``cfg`` that is closed by ``'M'`` becomes one stage with as
    many convolutions as the run has integers. Every convolution in that stage outputs
    the run's last integer. Integers after the final ``'M'`` are ignored. Stages are
    numbered from 1.

    Args:
        cfg: Sequence of ints (conv widths) and ``'M'`` markers.
        batch_norm: Forwarded to every stage.
        decisioner: Optional sequence indexed by 0-based stage position; the matching
            entry is passed to that stage.
        transfer_types, source_input_info, ru_units: Forwarded to every stage.

    Raises:
        AssertionError: if an ``'M'`` closes a run with no integers.
    """
    stages = []
    stage_in, stage_out = 3, 3
    pending_convs = 0
    for entry in cfg:
        if entry != 'M':
            pending_convs += 1
            stage_out = int(entry)
            continue
        assert pending_convs > 0, 'Invalid configuration for the VGG Network'
        stage_decisioner = decisioner[len(stages)] if decisioner else None
        stages.append(
            ConvolutionLayer(len(stages) + 1, stage_in, stage_out, layers=pending_convs, batch_norm=batch_norm,
                             decisioner=stage_decisioner, transfer_types=transfer_types,
                             source_input_info=source_input_info, ru_units=ru_units))
        stage_in = stage_out
        pending_convs = 0
    return nn.Sequential(*stages)

def make_layers_old(cfg, batch_norm=False):
    """Build a flat torchvision-style VGG feature stack from ``cfg``.

    Each integer adds ``Conv2d(3x3, padding=1)`` -> [``BatchNorm2d``] -> ``ReLU``;
    each ``'M'`` adds a 2x2 max-pool with stride 2. The input is assumed to have 3 channels.
    """
    modules = []
    channels = 3
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)

# Full-size VGG layer specifications. Integers are 3x3 conv output widths and 'M' ends a
# stage: make_layers turns each run of integers closed by 'M' into one ConvolutionLayer,
# which finishes with a 2x2 max-pool.
# 'A' -> vgg11, 'B' -> vgg13, 'D' -> vgg16, 'E' -> vgg19 (with or without _bn).
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

# Compact specifications with the same encoding as ``cfg``, one conv per stage.
# 'A' -> vgg4 / vgg4_bn (3 stages), 'B' -> vgg9 / vgg9_bn (4 stages);
# 'D' is not referenced by the factories below.
cfg_small = {
    'A': [64, 'M', 128, 'M', 512, 'M'],
    'B': [64, 'M', 128, 'M', 256, 'M', 512, 'M'],
    'D': [64, 'M', 128, 'M', 256, 'M', 512, 'M', 512, 'M'],
}


def vgg4(pretrained=False, **kwargs):
    """Build the compact VGG on ``cfg_small['A']`` without batch normalisation.

    The network has three ConvolutionLayer stages and a single linear head
    (``compact=True``). ``pretrained`` only prints a notice, because no ImageNet weights
    are provided for the small configurations. ``kwargs`` are forwarded to ``VGG``.
    """
    net = VGG(cfg_small['A'], name='vgg4', compact=True, **kwargs)
    # Three stages: narrow the five-entry class defaults to three ids.
    # NOTE(review): this also overwrites source_feature_ids that VGG.__init__ took from
    # source_info — confirm intended.
    net.source_feature_ids = list(range(3))  # one per internal representation
    net.target_route_ids = list(range(3))  # one per ConvolutionLayer
    if pretrained:
        print('Imagenet Pretrained network is not applicable for small VGG networks. Skipping this now.')
    return net


def vgg4_bn(pretrained=False, **kwargs):
    """Build the compact VGG on ``cfg_small['A']`` with batch normalisation.

    The network has three ConvolutionLayer stages and a single linear head
    (``compact=True``). ``pretrained`` only prints a notice, because no ImageNet weights
    are provided for the small configurations. ``kwargs`` are forwarded to ``VGG``.
    """
    net = VGG(cfg_small['A'], name='vgg4_bn', batch_norm=True, compact=True, **kwargs)
    # Three stages: narrow the five-entry class defaults to three ids.
    # NOTE(review): this also overwrites source_feature_ids that VGG.__init__ took from
    # source_info — confirm intended.
    net.source_feature_ids = list(range(3))  # one per internal representation
    net.target_route_ids = list(range(3))  # one per ConvolutionLayer
    if pretrained:
        print('Imagenet Pretrained network is not applicable for small VGG networks. Skipping this now.')
    return net


def vgg9(pretrained=False, **kwargs):
    """Build the compact VGG on ``cfg_small['B']`` without batch normalisation.

    The network has four ConvolutionLayer stages and a single linear head
    (``compact=True``). ``pretrained`` only prints a notice, because no ImageNet weights
    are provided for the small configurations. ``kwargs`` are forwarded to ``VGG``.

    NOTE(review): unlike vgg4, this keeps the five-entry class defaults for
    source_feature_ids / target_route_ids although there are only four stages — confirm.
    """
    net = VGG(cfg_small['B'], name='vgg9', compact=True, **kwargs)
    if pretrained:
        print('Imagenet Pretrained network is not applicable for small VGG networks. Skipping this now.')
    return net


def vgg9_bn(pretrained=False, **kwargs):
    """Build the compact VGG on ``cfg_small['B']`` with batch normalisation.

    The network has four ConvolutionLayer stages and a single linear head
    (``compact=True``). ``pretrained`` only prints a notice, because no ImageNet weights
    are provided for the small configurations. ``kwargs`` are forwarded to ``VGG``.

    NOTE(review): unlike vgg4_bn, this keeps the five-entry class defaults for
    source_feature_ids / target_route_ids although there are only four stages — confirm.
    """
    net = VGG(cfg_small['B'], name='vgg9_bn', batch_norm=True, compact=True, **kwargs)
    if pretrained:
        print('Imagenet Pretrained network is not applicable for small VGG networks. Skipping this now.')
    return net



def vgg11(pretrained=False, **kwargs):
    """Build VGG-11 (``cfg['A']``) without batch normalisation.

    Args:
        pretrained: Load the checkpoint at ``model_urls['vgg11']``.
        **kwargs: Forwarded to ``VGG``.

    NOTE(review): the checkpoints appear to be torchvision's, whose parameter names
    presumably differ from this class's ``layers``/``fc`` submodules; strict loading may
    fail — confirm before relying on ``pretrained=True``.
    """
    net = VGG(cfg['A'], name='vgg11', **kwargs)
    if not pretrained:
        return net
    checkpoint = model_zoo.load_url(model_urls['vgg11'])
    net.load_state_dict(checkpoint)
    return net


def vgg11_bn(pretrained=False, **kwargs):
    """Build VGG-11 (``cfg['A']``) with batch normalisation.

    Args:
        pretrained: Load the checkpoint at ``model_urls['vgg11_bn']``.
        **kwargs: Forwarded to ``VGG``.

    NOTE(review): the checkpoints appear to be torchvision's, whose parameter names
    presumably differ from this class's ``layers``/``fc`` submodules; strict loading may
    fail — confirm before relying on ``pretrained=True``.
    """
    net = VGG(cfg['A'], name='vgg11_bn', batch_norm=True, **kwargs)
    if not pretrained:
        return net
    checkpoint = model_zoo.load_url(model_urls['vgg11_bn'])
    net.load_state_dict(checkpoint)
    return net


def vgg13(pretrained=False, **kwargs):
    """Build VGG-13 (``cfg['B']``) without batch normalisation.

    Args:
        pretrained: Load the checkpoint at ``model_urls['vgg13']``.
        **kwargs: Forwarded to ``VGG``.

    NOTE(review): the checkpoints appear to be torchvision's, whose parameter names
    presumably differ from this class's ``layers``/``fc`` submodules; strict loading may
    fail — confirm before relying on ``pretrained=True``.
    """
    net = VGG(cfg['B'], name='vgg13', **kwargs)
    if not pretrained:
        return net
    checkpoint = model_zoo.load_url(model_urls['vgg13'])
    net.load_state_dict(checkpoint)
    return net


def vgg13_bn(pretrained=False, **kwargs):
    """Build VGG-13 (``cfg['B']``) with batch normalisation.

    Args:
        pretrained: Load the checkpoint at ``model_urls['vgg13_bn']``.
        **kwargs: Forwarded to ``VGG``.

    NOTE(review): the checkpoints appear to be torchvision's, whose parameter names
    presumably differ from this class's ``layers``/``fc`` submodules; strict loading may
    fail — confirm before relying on ``pretrained=True``.
    """
    net = VGG(cfg['B'], name='vgg13_bn', batch_norm=True, **kwargs)
    if not pretrained:
        return net
    checkpoint = model_zoo.load_url(model_urls['vgg13_bn'])
    net.load_state_dict(checkpoint)
    return net


def vgg16(pretrained=False, **kwargs):
    """Build VGG-16 (``cfg['D']``) without batch normalisation.

    Args:
        pretrained: Load the checkpoint at ``model_urls['vgg16']``.
        **kwargs: Forwarded to ``VGG``.

    NOTE(review): the checkpoints appear to be torchvision's, whose parameter names
    presumably differ from this class's ``layers``/``fc`` submodules; strict loading may
    fail — confirm before relying on ``pretrained=True``.
    """
    net = VGG(cfg['D'], name='vgg16', **kwargs)
    if not pretrained:
        return net
    checkpoint = model_zoo.load_url(model_urls['vgg16'])
    net.load_state_dict(checkpoint)
    return net


def vgg16_bn(pretrained=False, **kwargs):
    """Build VGG-16 (``cfg['D']``) with batch normalisation.

    Args:
        pretrained: Load the checkpoint at ``model_urls['vgg16_bn']``.
        **kwargs: Forwarded to ``VGG``.

    NOTE(review): the checkpoints appear to be torchvision's, whose parameter names
    presumably differ from this class's ``layers``/``fc`` submodules; strict loading may
    fail — confirm before relying on ``pretrained=True``.
    """
    net = VGG(cfg['D'], name='vgg16_bn', batch_norm=True, **kwargs)
    if not pretrained:
        return net
    checkpoint = model_zoo.load_url(model_urls['vgg16_bn'])
    net.load_state_dict(checkpoint)
    return net


def vgg19(pretrained=False, **kwargs):
    """Build VGG-19 (``cfg['E']``) without batch normalisation.

    Args:
        pretrained: Load the checkpoint at ``model_urls['vgg19']``.
        **kwargs: Forwarded to ``VGG``.

    NOTE(review): the checkpoints appear to be torchvision's, whose parameter names
    presumably differ from this class's ``layers``/``fc`` submodules; strict loading may
    fail — confirm before relying on ``pretrained=True``.
    """
    net = VGG(cfg['E'], name='vgg19', **kwargs)
    if not pretrained:
        return net
    checkpoint = model_zoo.load_url(model_urls['vgg19'])
    net.load_state_dict(checkpoint)
    return net


def vgg19_bn(pretrained=False, **kwargs):
    """Build VGG-19 (``cfg['E']``) with batch normalisation.

    Args:
        pretrained: Load the checkpoint at ``model_urls['vgg19_bn']``.
        **kwargs: Forwarded to ``VGG``.

    NOTE(review): the checkpoints appear to be torchvision's, whose parameter names
    presumably differ from this class's ``layers``/``fc`` submodules; strict loading may
    fail — confirm before relying on ``pretrained=True``.
    """
    net = VGG(cfg['E'], name='vgg19_bn', batch_norm=True, **kwargs)
    if not pretrained:
        return net
    checkpoint = model_zoo.load_url(model_urls['vgg19_bn'])
    net.load_state_dict(checkpoint)
    return net




if __name__ == "__main__":
    # Running this module as a script does nothing. The commented lines below are a
    # manual smoke test: build vgg4_bn, run a 5x3x32x32 batch through it, and print the
    # logits shape followed by the shape of each per-stage feature map.
    pass
#    x = torch.Tensor(5,3,32,32)
#    net = vgg4_bn()
#    y, feat = net(x)
#
#    print (y.size())
#    print()
#    for i in range(len(feat)):
#        print (feat[i].size())


