import torch
import torch.nn as nn
from torch.nn import init
from .layers import *
from .param_bank import *

class Block(nn.Module):
    """Pre-activation residual block: (BN -> ReLU -> 3x3 conv) twice, plus a shortcut.

    The convolution type is chosen per block: ``NaiveConv2d`` when ``naive`` is
    set, ``SConv2d`` drawing from ``bank`` when a bank is given, and a plain
    bias-free ``nn.Conv2d`` otherwise. Batch norm is ``ConditionalBatchNorm2d``
    when ``ensemble`` is truthy and ``nn.BatchNorm2d`` otherwise. When the input
    and output widths differ, a 1x1 convolution (with the block's stride) is
    used on the shortcut and is fed the pre-activated input.
    """

    def __init__(self, in_planes, out_planes, stride, bank=None, ensemble=None, naive=False):
        super(Block, self).__init__()
        self.ensemble = ensemble
        self.naive = naive

        def make_norm(planes):
            # Ensemble-aware normalisation only when an ensemble is configured.
            if self.ensemble:
                return ConditionalBatchNorm2d(planes, ensemble)
            return nn.BatchNorm2d(planes)

        def make_conv(cin, cout, ksize, conv_stride, pad):
            # Same three-way choice for every convolution in the block.
            if naive:
                return NaiveConv2d(cin, cout, kernel_size=ksize, stride=conv_stride,
                                   padding=pad, ensemble=ensemble)
            if bank:
                return SConv2d(bank, cin, cout, kernel_size=ksize, stride=conv_stride,
                               padding=pad, ensemble=ensemble)
            return nn.Conv2d(cin, cout, kernel_size=ksize, stride=conv_stride,
                             padding=pad, bias=False)

        # Creation order is kept (bn1, conv1, bn2, conv2, relu, shortcut) so that
        # submodule registration order stays the same.
        self.bn1 = make_norm(in_planes)
        self.conv1 = make_conv(in_planes, out_planes, 3, stride, 1)
        self.bn2 = make_norm(out_planes)
        self.conv2 = make_conv(out_planes, out_planes, 3, 1, 1)

        self.relu = nn.ReLU(inplace=True)
        self.equalInOut = (in_planes == out_planes)
        self.convShortcut = None
        if not self.equalInOut:
            self.convShortcut = make_conv(in_planes, out_planes, 1, stride, 0)

    def forward(self, x):
        """Apply the block.

        ``x`` is either a tensor, or a tuple carrying extra per-call arguments:
        ``(x, ensemble_idx)`` when the block is naive and
        ``(x, ensemble_idx, interpolate)`` otherwise. The return value has the
        same form as the input (tensor in, tensor out; tuple in, tuple out with
        the extra arguments passed through unchanged).
        """
        # Collect the extra positional arguments that accompany the tensor.
        if not isinstance(x, tuple):
            extras = ()
        elif self.naive:
            x, ensemble_idx = x
            extras = (ensemble_idx,)
        else:
            x, ensemble_idx, interpolate = x
            extras = (ensemble_idx, interpolate)

        # NOTE(review): bn1 only ever receives ensemble_idx, while bn2 and the
        # convolutions also receive interpolate in the non-naive case - confirm
        # this asymmetry is intended.
        pre_act = self.relu(self.bn1(x, *extras[:1]))

        # With matching widths the raw input is the residual; otherwise the
        # pre-activated tensor feeds the shortcut convolution.
        residual = x if self.equalInOut else pre_act

        out = self.conv1(pre_act, *extras)
        out = self.relu(self.bn2(out, *extras))
        out = self.conv2(out, *extras)

        if self.convShortcut is not None:
            residual = self.convShortcut(residual, *extras)

        total = out + residual
        return (total, *extras) if extras else total

class SWRN(nn.Module):
    """Wide-ResNet-style classifier with optional parameter sharing and ensembling.

    The network is a 3x3 stem convolution, three stages of ``Block`` modules
    (widths ``16*width``, ``32*width``, ``64*width``; strides 1, 2, 2), a final
    batch norm + ReLU, an 8x8 average pool and a 1x1-convolution classifier.

    * ``share_type != 'none'`` creates a ``ParameterGroups`` bank and builds the
      convolutions as ``SConv2d`` layers on top of it.
    * ``naive=True`` builds ``NaiveConv2d`` layers instead.
    * ``ensemble`` switches batch norm to ``ConditionalBatchNorm2d`` and, when a
      bank is present, creates an ``nn.Embedding`` holding one row of layer
      coefficients per ensemble member.
    * ``mimo_heads`` multiplies the stem's input channels by the number of heads
      and creates one classifier per head.
    """

    def __init__(self, share_type, upsample_type, upsample_window, depth, width, num_templates, max_params, num_classes, groups, ensemble=None, mimo_heads=None, share_class=True, shortcut=True, naive=False, interpolate=None, ensemble_round_robin=False):
        super(SWRN, self).__init__()

        n_channels = [16, 16*width, 32*width, 64*width]
        # Three stages of num_blocks blocks, two convolutions each, plus four
        # remaining layers: depth must have the form 6 * n + 4.
        assert (depth - 4) % 6 == 0, 'depth must satisfy depth = 6 * n + 4'
        num_blocks = (depth - 4) // 6
        print ('SWRN : Depth : {} , Widen Factor : {}, Templates per Group : {}'.format(depth, width, num_templates))

        self.num_classes = num_classes
        self.num_templates = num_templates
        self.bank = None
        self.ensemble = ensemble
        self.mimo_heads = mimo_heads
        self.share_class = share_class
        self.naive = naive
        self.interpolate = interpolate
        if share_type != 'none':
            self.bank = ParameterGroups(groups, share_type, upsample_type, upsample_window, max_params, num_templates, ensemble=ensemble, shortcut=shortcut, ensemble_round_robin=ensemble_round_robin)

        # With MIMO, the inputs of all heads are stacked along the channel axis.
        if self.mimo_heads is None:
            in_dims = 3
        else:
            in_dims = 3 * self.mimo_heads

        if naive:
            self.conv_3x3 = NaiveConv2d(in_dims, n_channels[0], kernel_size=3, stride=1, padding=1, ensemble=ensemble)
        elif self.bank:
            self.conv_3x3 = SConv2d(self.bank, in_dims, n_channels[0], kernel_size=3, stride=1, padding=1, ensemble=ensemble)
        else:
            self.conv_3x3 = nn.Conv2d(in_dims, n_channels[0], kernel_size=3, stride=1, padding=1, bias=False)

        self.stage_1 = self._make_layer(n_channels[0], n_channels[1], num_blocks, 1, ensemble=ensemble, naive=naive)
        self.stage_2 = self._make_layer(n_channels[1], n_channels[2], num_blocks, 2, ensemble=ensemble, naive=naive)
        self.stage_3 = self._make_layer(n_channels[2], n_channels[3], num_blocks, 2, ensemble=ensemble, naive=naive)

        if ensemble is not None:
            self.last_bn = ConditionalBatchNorm2d(n_channels[3], ensemble)
        else:
            self.last_bn = nn.BatchNorm2d(n_channels[3])
        self.lastact = nn.ReLU(inplace=True)
        # Fixed 8x8 pooling window.
        self.avgpool = nn.AvgPool2d(8)

        if mimo_heads is None:
            if naive:
                self.classifier = NaiveConv2d(n_channels[3], num_classes, kernel_size=1, ensemble=ensemble)
            elif self.bank:
                self.classifier = SConv2d(self.bank, n_channels[3], num_classes, kernel_size=1, ensemble=ensemble)
            else:
                self.classifier = nn.Conv2d(n_channels[3], num_classes, kernel_size=1)
        else:
            if self.bank and self.share_class:
                # One bank-backed classifier per MIMO head.
                self.classifier = nn.ModuleList([SConv2d(self.bank, n_channels[3], num_classes, kernel_size=1, ensemble=ensemble) for _ in range(mimo_heads)])
            elif self.bank:
                # Unshared classifiers: indexed as classifier[member][head].
                self.classifier = nn.ModuleList([nn.ModuleList([nn.Conv2d(n_channels[3], num_classes, kernel_size=1) for _ in range(mimo_heads)]) for _ in range(ensemble)])
            else:
                self.classifier = nn.ModuleList([nn.Conv2d(n_channels[3], num_classes, kernel_size=1) for _ in range(mimo_heads)])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight)
                m.bias.data.zero_()

        if self.bank:
            self.bank.setup_bank()
            if self.ensemble is not None:
                # One embedding row per ensemble member; each row is the
                # concatenation of every layer's coefficient vector, and the
                # split sizes are remembered so forward() can cut it up again.
                coefficients = self.get_coefficients()
                self.coefficient_weight_splits = [c.coefficients.numel() for c in coefficients]
                self.total_coefficients = sum(self.coefficient_weight_splits)
                self.coeff_embeddings = torch.nn.Embedding(self.ensemble, self.total_coefficients)

    def _make_layer(self, in_planes, out_planes, num_blocks, stride=1, ensemble=None, naive=False):
        """Build one stage: a (possibly strided) block followed by stride-1 blocks."""
        blocks = [Block(in_planes, out_planes, stride, self.bank, ensemble=ensemble, naive=naive)]
        for _ in range(1, num_blocks):
            blocks.append(Block(out_planes, out_planes, 1, self.bank, ensemble=ensemble, naive=naive))
        return nn.Sequential(*blocks)

    @staticmethod
    def _fold_members(x):
        """Fold dim 1 of ``x`` into dim 0, member-major.

        ``x`` is cut into size-1 slices along dim 1, the slices are concatenated
        along dim 0 and the leftover singleton dim 1 is removed. Only that axis
        is squeezed, so other size-1 axes (for example a folded batch of one or
        a single-channel input) are preserved.
        Presumably ``x`` is (batch, members, C, H, W) - confirm against the caller.
        """
        return torch.cat(torch.split(x, 1, dim=1)).squeeze(1)

    @staticmethod
    def _pack_mimo_inputs(x, heads):
        """Group every ``heads`` consecutive samples and merge them with dim 1."""
        x = torch.reshape(x, (-1, heads) + x.shape[1:])
        return torch.reshape(x, (x.shape[0], -1) + x.shape[3:])

    def forward(self, x, ensemble_idx = None):
        """Run the network.

        Without ``ensemble_idx`` every layer is called with the tensor only.
        With ``ensemble_idx`` the input's dim 1 is first folded into the batch
        axis and ``ensemble_idx`` (plus ``self.interpolate`` for non-naive
        layers) is forwarded to every ensemble-aware layer.

        Returns the classifier output flattened to ``(rows, -1)``.
        """
        if ensemble_idx is None:
            return self._forward_single(x)
        return self._forward_ensemble(x, ensemble_idx)

    def _forward_single(self, x):
        """Forward pass that passes no ensemble arguments to any layer."""
        if self.mimo_heads is not None:
            x = self._pack_mimo_inputs(x, self.mimo_heads)
        x = self.conv_3x3(x)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = self.last_bn(x)
        x = self.lastact(x)
        x = self.avgpool(x)
        if self.mimo_heads is not None:
            # Stack the per-head outputs on dim 1, then merge that axis into dim 0.
            x = torch.stack([head(x) for head in self.classifier], dim=1)
            x = torch.reshape(x, (-1,) + x.shape[2:])
        else:
            x = self.classifier(x)
        return x.view(x.size(0), -1)

    def _forward_ensemble(self, x, ensemble_idx):
        """Forward pass that threads ``ensemble_idx`` through every layer."""
        if not self.naive:
            # Push the current embedding rows into the layers: each layer gets
            # its own column slice of the (members, total_coefficients) table.
            member_coefficients = self.coeff_embeddings.weight
            split_member_coefficients = torch.split(member_coefficients, self.coefficient_weight_splits, dim=1)
            self.set_coefficients(split_member_coefficients)

        x = self._fold_members(x)
        if self.mimo_heads is not None:
            x = self._pack_mimo_inputs(x, self.mimo_heads)

        if not self.naive:
            x = self.conv_3x3(x, ensemble_idx, self.interpolate)
            x, _, _ = self.stage_1((x, ensemble_idx, self.interpolate))
            x, _, _ = self.stage_2((x, ensemble_idx, self.interpolate))
            x, _, _ = self.stage_3((x, ensemble_idx, self.interpolate))
        else:
            x = self.conv_3x3(x, ensemble_idx)
            x, _ = self.stage_1((x, ensemble_idx))
            x, _ = self.stage_2((x, ensemble_idx))
            x, _ = self.stage_3((x, ensemble_idx))

        x = self.last_bn(x, ensemble_idx, self.interpolate)
        x = self.lastact(x)
        x = self.avgpool(x)

        if self.mimo_heads is not None:
            if not self.share_class:
                # Equal-sized chunks along dim 0, one per member, each scored
                # by that member's own list of heads.
                chunks = torch.split(x, x.shape[0]//self.ensemble)
                per_member = [
                    torch.stack([head(chunk) for head in self.classifier[member_idx]], dim=1)
                    for member_idx, chunk in enumerate(chunks)
                ]
                x = torch.cat(per_member)
            else:
                x = torch.stack([head(x, ensemble_idx) for head in self.classifier], dim=1)
            x = torch.reshape(x, (-1,) + x.shape[2:])
        else:
            if not self.naive:
                x = self.classifier(x, ensemble_idx, self.interpolate)
            else:
                x = self.classifier(x, ensemble_idx)
        return x.view(x.size(0), -1)

    def _coefficient_layers(self):
        """Yield, in ``self.modules()`` order, shared layers that own coefficients."""
        for m in self.modules():
            if isinstance(m, (SConv2d, SLinear)) and hasattr(m, 'coefficients'):
                yield m

    def get_coefficients(self):
        """Return the ``coefficients`` attribute of every shared layer that has one."""
        return [m.coefficients for m in self._coefficient_layers()]

    def set_coefficients(self, coefficients):
        """Assign ``coefficients[i]`` to the i-th shared layer's coefficient holder.

        The layer order is the same one ``get_coefficients`` uses.
        """
        for idx, m in enumerate(self._coefficient_layers()):
            m.coefficients.coefficients = coefficients[idx]


def swrn(share_type, upsample_type, upsample_window, depth, width, num_templates, max_params, num_classes=10, groups=None, ensemble=None, mimo_heads=None, share_class=True, shortcut=True, naive=False, interpolate=None, ensemble_round_robin=False, **kwargs):
    """Construct an ``SWRN`` from the given hyper-parameters.

    Any additional keyword arguments are accepted and ignored.
    """
    return SWRN(
        share_type, upsample_type, upsample_window, depth, width,
        num_templates, max_params, num_classes, groups,
        ensemble=ensemble,
        mimo_heads=mimo_heads,
        share_class=share_class,
        shortcut=shortcut,
        naive=naive,
        interpolate=interpolate,
        ensemble_round_robin=ensemble_round_robin,
    )


# ImageNet model.
# This is a ResNet v1.5-style model (stride 2 on 3x3 convolutions).
# In contrast to the above, this applies batchnorm/relu after convolution.

class ConvBNRelu(nn.Module):
    """Convolution followed by batch norm and an optional in-place ReLU.

    The convolution is ``NaiveConv2d`` when ``naive`` is set, ``SConv2d`` backed
    by ``bank`` when a bank is given, and a bias-free ``nn.Conv2d`` otherwise.
    Batch norm is ``ConditionalBatchNorm2d`` whenever ``ensemble`` is not None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 relu=True, bank=None, ensemble = None, naive=False):
        super().__init__()
        self.bank = bank

        # Geometry shared by all three convolution flavours.
        geometry = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        if naive:
            self.conv = NaiveConv2d(in_channels, out_channels,
                                    ensemble=ensemble, **geometry)
        elif self.bank:
            self.conv = SConv2d(self.bank, in_channels, out_channels,
                                ensemble=ensemble, **geometry)
        else:
            self.conv = nn.Conv2d(in_channels, out_channels,
                                  bias=False, **geometry)

        if ensemble is None:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = ConditionalBatchNorm2d(out_channels, ensemble)

        # ``None`` marks "no activation"; forward() checks for it.
        self.relu = nn.ReLU(inplace=True) if relu else None

    def forward(self, x, ensemble_idx = None):
        """Apply conv -> bn (-> relu), forwarding ``ensemble_idx`` when given."""
        if ensemble_idx is None:
            out = self.bn(self.conv(x))
        else:
            out = self.bn(self.conv(x, ensemble_idx), ensemble_idx)
        if self.relu is None:
            return out
        return self.relu(out)


class BottleneckBlock(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 ``ConvBNRelu`` units plus a skip path.

    The block outputs ``4 * mid_channels`` channels; ``width`` scales only the
    channel count of the inner 3x3 unit. When ``downsample`` is set the 3x3 unit
    uses stride 2, and the skip path either uses a stride-2 1x1 unit or, with
    ``pool_residual``, an average pool followed by a stride-1 1x1 unit. Without
    downsampling, a 1x1 unit is used on the skip path only when the channel
    counts differ; otherwise the input is added back unchanged.
    """

    def __init__(self, in_channels, mid_channels, downsample, bank=None,
                 width=1, pool_residual=False, ensemble = None, naive=False):
        super().__init__()
        self.bank = bank
        self.out_channels = 4 * mid_channels
        # Width factor applies only to inner 3x3 convolution.
        mid_channels = int(mid_channels * width)
        self.ensemble = ensemble

        # Options every ConvBNRelu unit in this block shares.
        shared = dict(bank=self.bank, ensemble=self.ensemble, naive=naive)

        # Skip connection.
        if downsample:
            if pool_residual:
                pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
                conv = ConvBNRelu(
                    in_channels, self.out_channels, kernel_size=1,
                    stride=1, padding=0, relu=False, **shared)
                self.skip_connection = nn.Sequential(pool, conv)
            else:
                self.skip_connection = ConvBNRelu(
                    in_channels, self.out_channels, kernel_size=1,
                    stride=2, padding=0, relu=False, **shared)
        elif in_channels != self.out_channels:
            self.skip_connection = ConvBNRelu(
                in_channels, self.out_channels, kernel_size=1,
                stride=1, padding=0, relu=False, **shared)
        else:
            self.skip_connection = None

        # Main branch.
        self.in_conv = ConvBNRelu(
            in_channels, mid_channels, kernel_size=1,
            stride=1, padding=0, **shared)
        self.mid_conv = ConvBNRelu(
            mid_channels, mid_channels, kernel_size=3,
            stride=(2 if downsample else 1), padding=1, **shared)
        self.out_conv = ConvBNRelu(
            mid_channels, self.out_channels, kernel_size=1,
            stride=1, padding=0, relu=False, **shared)
        self.out_relu = nn.ReLU(inplace=True)

    def _shortcut(self, x, ensemble_idx):
        """Return the residual for ``x``, forwarding ``ensemble_idx`` when given."""
        skip = self.skip_connection
        if skip is None:
            return x
        if ensemble_idx is None:
            return skip(x)
        if isinstance(skip, nn.Sequential):
            # nn.Sequential.forward accepts a single input, so the pooled
            # variant is unrolled by hand: pool first, then the conv unit
            # together with the ensemble index.
            pool, conv = skip[0], skip[1]
            return conv(pool(x), ensemble_idx)
        return skip(x, ensemble_idx)

    def forward(self, x):
        """Apply the block.

        ``x`` is a tensor or an ``(x, ensemble_idx)`` tuple. When an ensemble
        index is present (not None) the result is returned as
        ``(out, ensemble_idx)``; otherwise a plain tensor is returned.
        """
        ensemble_idx = None
        if isinstance(x, tuple):
            x, ensemble_idx = x

        residual = self._shortcut(x, ensemble_idx)

        if ensemble_idx is None:
            out = self.out_conv(self.mid_conv(self.in_conv(x)))
        else:
            out = self.in_conv(x, ensemble_idx)
            out = self.mid_conv(out, ensemble_idx)
            out = self.out_conv(out, ensemble_idx)

        out += residual
        out = self.out_relu(out)
        if ensemble_idx is None:
            return out
        return (out, ensemble_idx)

class ResNet(nn.Module):
    """ImageNet-style residual network assembled from ``block`` modules.

    Layout: a 7x7 stride-2 ``ConvBNRelu`` stem, a 3x3 stride-2 max pool, one
    ``nn.Sequential`` of blocks per entry of ``module_sizes`` /
    ``module_channels`` (the first block of every module after the first one
    downsamples), global average pooling and a 1x1-convolution classifier.

    When a ``bank`` is given its ``setup_bank()`` is called after weight
    initialisation. With an ensemble on top of a bank, an ``nn.Embedding`` with
    one row per member stores the concatenated coefficients of all shared layers.
    """

    def __init__(self, block, module_sizes, module_channels, num_classes,
                    width=1, bank=None, pool_residual=False, ensemble = None, naive=False, ortho_init=False):
        super().__init__()
        self.bank = bank
        self.ensemble = ensemble
        self.naive = naive

        # Input trunk, ResNet style.
        self.conv1 = ConvBNRelu(3, module_channels[0], kernel_size=7,
                                stride=2, padding=3, bank=self.bank,
                                ensemble = self.ensemble, naive=naive)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Build the main network.
        modules = []
        out_channels = module_channels[0]
        for module_idx, (num_layers, mid_channels) in enumerate(zip(
                module_sizes, module_channels)):
            blocks = []
            for i in range(num_layers):
                # Each block consumes whatever the previous block produced.
                in_channels = out_channels
                downsample = i == 0 and module_idx > 0
                b = block(in_channels, mid_channels, downsample,
                          bank=self.bank, width=width,
                          pool_residual=pool_residual,
                          ensemble = self.ensemble, naive=naive)
                out_channels = b.out_channels
                blocks.append(b)
            modules.append(nn.Sequential(*blocks))
        self.block_modules = nn.Sequential(*modules)

        # Output: the classifier is a 1x1 convolution applied to the pooled map.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        if naive:
            self.fc = NaiveConv2d(out_channels, num_classes, kernel_size=1,ensemble=self.ensemble)
        elif self.bank:
            self.fc = SConv2d(self.bank, out_channels, num_classes, kernel_size=1,ensemble=self.ensemble)
        else:
            self.fc = nn.Conv2d(out_channels, num_classes, kernel_size=1)
        self._init_weights()

        if self.bank:
            self.bank.setup_bank()
            if self.ensemble is not None:
                # One embedding row per member, holding every shared layer's
                # coefficients back to back; the split sizes let forward() cut
                # a row back into per-layer pieces.
                coefficients = self.get_coefficients()
                self.coefficient_weight_splits = [c.coefficients.numel() for c in coefficients]
                self.total_coefficients = sum(self.coefficient_weight_splits)
                if ortho_init:
                    coeff_params = torch.zeros(self.ensemble, self.total_coefficients)
                    nn.init.orthogonal_(coeff_params)
                    # NOTE(review): Embedding.from_pretrained defaults to
                    # freeze=True, so these coefficients will not receive
                    # gradients - confirm that is intended.
                    self.coeff_embeddings = torch.nn.Embedding.from_pretrained(coeff_params)
                else:
                    self.coeff_embeddings = torch.nn.Embedding(self.ensemble, self.total_coefficients)

    def _init_weights(self):
        """Kaiming-init conv/linear weights and reset batch-norm affine parameters."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight)
                init.constant_(m.bias, 0)
        # Zero initialize the last batchnorm in each residual branch.
        for m in self.modules():
            if isinstance(m, BottleneckBlock):
                last_bn = m.out_conv.bn
                if isinstance(last_bn, ConditionalBatchNorm2d):
                    # Zero every plain batch-norm layer nested inside the
                    # conditional wrapper.
                    for bn in last_bn.modules():
                        if isinstance(bn, nn.BatchNorm2d):
                            init.constant_(bn.weight, 0)
                else:
                    init.constant_(last_bn.weight, 0)

    @staticmethod
    def _fold_members(x):
        """Fold dim 1 of ``x`` into dim 0, member-major.

        ``x`` is cut into size-1 slices along dim 1, the slices are concatenated
        along dim 0 and the leftover singleton dim 1 is removed. Only that axis
        is squeezed, so other size-1 axes (for example a folded batch of one)
        are preserved.
        Presumably ``x`` is (batch, members, C, H, W) - confirm against the caller.
        """
        return torch.cat(torch.split(x, 1, dim=1)).squeeze(1)

    def forward(self, x, ensemble_idx = None):
        """Run the network and return the classifier output flattened to ``(rows, -1)``.

        When the model has an ensemble, the layer coefficients are refreshed
        from the embedding table (non-naive models only) and dim 1 of ``x`` is
        folded into the batch axis. ``ensemble_idx``, when given, is forwarded
        to the stem, the blocks and the classifier.
        """
        if self.ensemble:
            if not self.naive:
                member_coefficients = self.coeff_embeddings.weight
                split_member_coefficients = torch.split(member_coefficients,self.coefficient_weight_splits,dim=1)
                self.set_coefficients(split_member_coefficients)
            x = self._fold_members(x)
        if ensemble_idx is None:
            x = self.maxpool(self.conv1(x))
            x = self.block_modules(x)
            x = self.fc(self.avgpool(x))
        else:
            x = self.maxpool(self.conv1(x, ensemble_idx))
            # Blocks pass (tensor, ensemble_idx) tuples along the Sequential.
            x, _ = self.block_modules((x, ensemble_idx))
            x = self.fc(self.avgpool(x), ensemble_idx)
        x = x.view(x.size(0),-1)
        return x

    def _coefficient_layers(self):
        """Yield, in ``self.modules()`` order, shared layers that own coefficients."""
        for m in self.modules():
            if isinstance(m, (SConv2d, SLinear)) and hasattr(m, 'coefficients'):
                yield m

    def get_coefficients(self):
        """Return the ``coefficients`` attribute of every shared layer that has one."""
        return [m.coefficients for m in self._coefficient_layers()]

    def set_coefficients(self, coefficients):
        """Assign ``coefficients[i]`` to the i-th shared layer's coefficient holder.

        The layer order is the same one ``get_coefficients`` uses.
        """
        for idx, m in enumerate(self._coefficient_layers()):
            m.coefficients.coefficients = coefficients[idx]

def swrn_imagenet(share_type, upsample_type, upsample_window, depth, width,
                  num_templates, max_params, num_classes=1000, groups=None, ensemble=None, shortcut=True, naive=False,ensemble_round_robin=False,ortho_init=False, **kwargs):
    """ResNet-50, with optional width (depth ignored for now, can generalize).

    A ``ParameterGroups`` bank is created unless ``share_type`` is ``'none'``.
    ``ortho_init`` is forwarded to ``ResNet`` so that it takes effect here as
    well. Additional keyword arguments are accepted and ignored.
    """
    bank = None
    if share_type != 'none':
        bank = ParameterGroups(groups, share_type, upsample_type,
                               upsample_window, max_params, num_templates, ensemble=ensemble, shortcut=shortcut, ensemble_round_robin=ensemble_round_robin)
    return ResNet(BottleneckBlock,
                  (3, 4, 6, 3), (64, 128, 256, 512),
                  num_classes, width=width, bank=bank,
                  pool_residual=False, ensemble=ensemble, naive=naive,
                  ortho_init=ortho_init)



def swrn_imagenet_reduced(share_type, upsample_type, upsample_window, depth,
                          width, num_templates, max_params, num_classes=1000,
                          groups=None):
    """ResNet-50 layout (3, 4, 6, 3 blocks) with narrower stages: 16, 32, 64, 96.

    ``depth`` is accepted but not used. A ``ParameterGroups`` bank is created
    unless ``share_type`` is ``'none'``.
    """
    if share_type == 'none':
        bank = None
    else:
        bank = ParameterGroups(groups, share_type, upsample_type,
                               upsample_window, max_params, num_templates)
    blocks_per_stage = (3, 4, 6, 3)
    stage_channels = (16, 32, 64, 96)
    return ResNet(BottleneckBlock, blocks_per_stage, stage_channels,
                  num_classes, width=width, bank=bank,
                  pool_residual=False)


def swrn_imagenet17(share_type, upsample_type, upsample_window, depth, width,
                    num_templates, max_params, num_classes=1000, groups=None, ensemble=None, shortcut=True, naive=False,ortho_init=False,ensemble_round_robin=False, **kwargs):
    """Funny Wide ResNet-17: one bottleneck block per stage, widths 64-512.

    ``depth`` is accepted but not used. A ``ParameterGroups`` bank is created
    unless ``share_type`` is ``'none'``; additional keyword arguments are
    accepted and ignored.
    """
    if share_type == 'none':
        bank = None
    else:
        bank = ParameterGroups(groups, share_type, upsample_type,
                               upsample_window, max_params, num_templates,
                               ensemble=ensemble, shortcut=shortcut,
                               ensemble_round_robin=ensemble_round_robin)
    blocks_per_stage = (1, 1, 1, 1)
    stage_channels = (64, 128, 256, 512)
    return ResNet(BottleneckBlock, blocks_per_stage, stage_channels,
                  num_classes, width=width, bank=bank,
                  pool_residual=False, ensemble=ensemble, naive=naive,
                  ortho_init=ortho_init)
