# Based on code taken from https://github.com/facebookresearch/open_lth

# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

from Layers import layers


class Block(nn.Module):
    """A basic residual block with two 3x3 convolutions.

    The main path is conv -> BN -> ReLU -> conv -> BN.  The result is added to the
    shortcut branch and passed through a final ReLU.

    Args:
        f_in: Number of input channels.
        f_out: Number of output channels.
        downsample: If True, the first convolution and the projection shortcut use
            stride 2; otherwise stride 1.
    """

    def __init__(self, f_in: int, f_out: int, downsample=False):
        super(Block, self).__init__()

        stride = 2 if downsample else 1
        self.conv1 = layers.Conv2d(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = layers.BatchNorm2d(f_out)
        self.conv2 = layers.Conv2d(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = layers.BatchNorm2d(f_out)

        # A projection shortcut (1x1 conv + BN) is used whenever the main path changes
        # the stride or the channel count; otherwise the shortcut is an identity layer.
        if downsample or f_in != f_out:
            self.shortcut = nn.Sequential(
                # Share conv1's stride so both branches produce the same spatial size,
                # including the case where only the channel count changes.
                layers.Conv2d(f_in, f_out, kernel_size=1, stride=stride, bias=False),
                layers.BatchNorm2d(f_out)
            )
        else:
            self.shortcut = layers.Identity2d(f_in)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)
    
class Block_simplex(nn.Module):
    """A residual block whose two 3x3 convolutions are `layers.Conv2d_simplex` layers.

    The projection shortcut, when present, uses a plain `layers.Conv2d` followed by
    batch norm.

    Args:
        f_in: Number of input channels.
        f_out: Number of output channels.
        downsample: If True, the first convolution and the projection shortcut use
            stride 2; otherwise stride 1.
        num_models: Forwarded to both `layers.Conv2d_simplex` constructors.
    """

    def __init__(self, f_in: int, f_out: int, downsample=False, num_models=3):
        super(Block_simplex, self).__init__()

        stride = 2 if downsample else 1
        self.conv1 = layers.Conv2d_simplex(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False, num_models=num_models)
        self.bn1 = layers.BatchNorm2d(f_out)
        self.conv2 = layers.Conv2d_simplex(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False, num_models=num_models)
        self.bn2 = layers.BatchNorm2d(f_out)

        # A projection shortcut (1x1 conv + BN) is used whenever the main path changes
        # the stride or the channel count; otherwise the shortcut is an identity layer.
        if downsample or f_in != f_out:
            self.shortcut = nn.Sequential(
                # Share conv1's stride so both branches produce the same spatial size,
                # including the case where only the channel count changes.
                layers.Conv2d(f_in, f_out, kernel_size=1, stride=stride, bias=False),
                layers.BatchNorm2d(f_out)
            )
        else:
            self.shortcut = layers.Identity2d(f_in)

    def freeze_bn(self):
        """Disable running-stat tracking on the block's batch-norm layers and set them to eval."""
        for bn in (self.bn1, self.bn2):
            bn.track_running_stats = False
            bn.eval()

        # Only shortcut sub-modules that are nn.BatchNorm2d instances are touched.
        for module in self.shortcut.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.track_running_stats = False
                module.eval()

    def freeze(self, vertex=0):
        """Freeze the block.

        Calls ``freeze(vertex)`` on both simplex convolutions.  An identity shortcut is
        frozen through its own ``freeze()``; a projection shortcut has ``requires_grad``
        switched off and any stored gradient dropped on every parameter.
        """
        self.conv1.freeze(vertex)
        self.conv2.freeze(vertex)
        if isinstance(self.shortcut, layers.Identity2d):
            self.shortcut.freeze()
        else:
            for param in self.shortcut.parameters():
                param.requires_grad = False
                param.grad = None

    def forward(self, x, alpha=None, vertex=0, fixed=False):
        """Run the block.

        Args:
            x: Input tensor.
            alpha: Passed to both simplex convolutions.  Defaults to a fresh
                ``[1.0, 1.0, 1.0]`` list on every call.
            vertex: Passed to both simplex convolutions.
            fixed: Passed to both simplex convolutions.
        """
        if alpha is None:
            # Build the default per call instead of sharing one mutable list object.
            alpha = [1.0, 1.0, 1.0]
        out = F.relu(self.bn1(self.conv1(x, alpha=alpha, vertex=vertex, fixed=fixed)))
        out = self.bn2(self.conv2(out, alpha=alpha, vertex=vertex, fixed=fixed))
        out += self.shortcut(x)
        return F.relu(out)
    
class Block_bezier(nn.Module):
    """A residual block whose two 3x3 convolutions are `layers.Conv2d_bezier` layers.

    The projection shortcut, when present, uses a plain `layers.Conv2d` followed by
    batch norm.

    Args:
        f_in: Number of input channels.
        f_out: Number of output channels.
        downsample: If True, the first convolution and the projection shortcut use
            stride 2; otherwise stride 1.
        num_models: Forwarded to both `layers.Conv2d_bezier` constructors.
    """

    def __init__(self, f_in: int, f_out: int, downsample=False, num_models=3):
        super(Block_bezier, self).__init__()

        stride = 2 if downsample else 1
        self.conv1 = layers.Conv2d_bezier(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False, num_models=num_models)
        self.bn1 = layers.BatchNorm2d(f_out)
        self.conv2 = layers.Conv2d_bezier(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False, num_models=num_models)
        self.bn2 = layers.BatchNorm2d(f_out)

        # A projection shortcut (1x1 conv + BN) is used whenever the main path changes
        # the stride or the channel count; otherwise the shortcut is an identity layer.
        if downsample or f_in != f_out:
            self.shortcut = nn.Sequential(
                # Share conv1's stride so both branches produce the same spatial size,
                # including the case where only the channel count changes.
                layers.Conv2d(f_in, f_out, kernel_size=1, stride=stride, bias=False),
                layers.BatchNorm2d(f_out)
            )
        else:
            self.shortcut = layers.Identity2d(f_in)

    def freeze_bn(self):
        """Disable running-stat tracking on the block's batch-norm layers and set them to eval."""
        for bn in (self.bn1, self.bn2):
            bn.track_running_stats = False
            bn.eval()

        # Only shortcut sub-modules that are nn.BatchNorm2d instances are touched.
        for module in self.shortcut.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.track_running_stats = False
                module.eval()

    def freeze(self, vertex=0):
        """Freeze the block.

        Calls ``freeze(vertex)`` on both Bezier convolutions.  An identity shortcut is
        frozen through its own ``freeze()``; a projection shortcut has ``requires_grad``
        switched off and any stored gradient dropped on every parameter.
        """
        self.conv1.freeze(vertex)
        self.conv2.freeze(vertex)
        if isinstance(self.shortcut, layers.Identity2d):
            self.shortcut.freeze()
        else:
            for param in self.shortcut.parameters():
                param.requires_grad = False
                param.grad = None

    def forward(self, x, lambda_=0.5, vertex=0, fixed=False):
        """Run the block; ``lambda_``, ``vertex`` and ``fixed`` go to both Bezier convolutions."""
        out = F.relu(self.bn1(self.conv1(x, lambda_=lambda_, vertex=vertex, fixed=fixed)))
        out = self.bn2(self.conv2(out, lambda_=lambda_, vertex=vertex, fixed=fixed))
        out += self.shortcut(x)
        return F.relu(out)
    
class Block_simplex_v2(nn.Module):
    """A residual block whose convolutions, including the shortcut projection, are simplex layers.

    Unlike `nn.Sequential`-based shortcuts, the projection is kept as two attributes
    (`shortcut1` convolution, `shortcut2` batch norm) so that ``vertex`` can be passed
    to the convolution.

    Args:
        f_in: Number of input channels.
        f_out: Number of output channels.
        downsample: If True, the first convolution and the projection shortcut use
            stride 2; otherwise stride 1.
    """

    def __init__(self, f_in: int, f_out: int, downsample=False):
        super(Block_simplex_v2, self).__init__()

        self.downsample = downsample
        self.f_in = f_in
        self.f_out = f_out

        stride = 2 if downsample else 1
        self.conv1 = layers.Conv2d_simplex(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = layers.BatchNorm2d(f_out)
        self.conv2 = layers.Conv2d_simplex(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = layers.BatchNorm2d(f_out)

        # A projection shortcut (1x1 conv + BN) is used whenever the main path changes
        # the stride or the channel count; otherwise the shortcut is an identity layer.
        if downsample or f_in != f_out:
            # No trailing comma here: the attribute must be the module itself (not a
            # 1-tuple) so that it is callable and registered as a sub-module.
            # It shares conv1's stride so both branches produce the same spatial size.
            self.shortcut1 = layers.Conv2d_simplex(f_in, f_out, kernel_size=1, stride=stride, bias=False)
            self.shortcut2 = layers.BatchNorm2d(f_out)
        else:
            self.shortcut = layers.Identity2d(f_in)

    def forward(self, x, vertex=0):
        """Run the block; ``vertex`` is passed to every simplex convolution."""
        out = F.relu(self.bn1(self.conv1(x, vertex=vertex)))
        out = self.bn2(self.conv2(out, vertex=vertex))
        if self.downsample or self.f_in != self.f_out:
            out += self.shortcut2(self.shortcut1(x, vertex=vertex))
        else:
            out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """CIFAR-style residual network assembled from `Block` units.

    Args:
        plan: Sequence of ``(filters, num_blocks)`` pairs, one per segment.  The first
            block of every segment after the first is built with ``downsample=True``.
        num_classes: Output size of the final linear layer.
        dense_classifier: If true, the classifier is a plain `nn.Linear` instead of
            `layers.Linear`.
    """

    def __init__(self, plan, num_classes, dense_classifier):
        super(ResNet, self).__init__()

        # Stem: one 3x3 convolution on 3-channel input, followed by batch norm.
        width = plan[0][0]
        self.conv = layers.Conv2d(3, width, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = layers.BatchNorm2d(width)

        # Residual segments; `width` tracks the channel count flowing into the next block.
        stack = []
        for stage_idx, (filters, depth) in enumerate(plan):
            for position in range(depth):
                opens_later_stage = stage_idx > 0 and position == 0
                stack.append(Block(width, filters, opens_later_stage))
                width = filters
        self.blocks = nn.Sequential(*stack)

        # Classifier head; the layers.Linear is replaced when a dense classifier is requested.
        head_inputs = plan[-1][0]
        self.fc = layers.Linear(head_inputs, num_classes)
        if dense_classifier:
            self.fc = nn.Linear(head_inputs, num_classes)

        self._initialize_weights()

    def forward(self, x):
        """Stem -> blocks -> average pool over the full width -> flatten -> classifier."""
        features = self.blocks(F.relu(self.bn(self.conv(x))))
        pooled = F.avg_pool2d(features, features.size(3))
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

    def _initialize_weights(self):
        """Kaiming-normal init for conv/linear weights (zero bias); BN gets weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, (layers.Linear, nn.Linear, layers.Conv2d)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, layers.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                
                
class ResNet_simplex(nn.Module):
    """CIFAR-style residual network assembled from `Block_simplex` units.

    The stem convolution, the block convolutions and the classifier are ``*_simplex``
    layers from `layers`; ``alpha``, ``vertex`` and ``fixed`` given to `forward` are
    handed unchanged to each of them.

    Args:
        plan: Sequence of ``(filters, num_blocks)`` pairs, one per segment.  The first
            block of every segment after the first is built with ``downsample=True``.
        num_classes: Output size of the classifier.
        dense_classifier: Accepted for signature parity with `ResNet`; not used here.
        num_models: Forwarded to every ``*_simplex`` layer and block.
    """

    def __init__(self, plan, num_classes, dense_classifier, num_models=3):
        super(ResNet_simplex, self).__init__()

        # Initial convolution.
        current_filters = plan[0][0]
        self.conv = layers.Conv2d_simplex(3, current_filters, kernel_size=3, stride=1, padding=1, bias=False, num_models=num_models)
        self.bn = layers.BatchNorm2d(current_filters)

        # The subsequent blocks of the ResNet.
        blocks = []
        for segment_index, (filters, num_blocks) in enumerate(plan):
            for block_index in range(num_blocks):
                downsample = segment_index > 0 and block_index == 0
                blocks.append(Block_simplex(current_filters, filters, downsample, num_models))
                current_filters = filters

        # A ModuleList (not Sequential) because each block needs extra forward arguments.
        self.blocks = nn.ModuleList(blocks)

        # NOTE: `dense_classifier` is ignored; the classifier is always a Linear_simplex.
        self.fc = layers.Linear_simplex(plan[-1][0], num_classes, num_models=num_models)

        self._initialize_weights()

    def freeze_bn(self):
        """Disable running-stat tracking on the stem BN, set it to eval, then do the same per block."""
        self.bn.track_running_stats = False
        self.bn.eval()

        for block in self.blocks:
            block.freeze_bn()

    def forward(self, x, alpha=None, vertex=0, fixed=False):
        """Run the network.

        Args:
            x: Input batch.
            alpha: Passed to every simplex layer.  Defaults to a fresh
                ``[1.0, 1.0, 1.0]`` list on every call.
            vertex: Passed to every simplex layer.
            fixed: Passed to every simplex layer.
        """
        if alpha is None:
            # Build the default per call instead of sharing one mutable list object.
            alpha = [1.0, 1.0, 1.0]
        out = F.relu(self.bn(self.conv(x, alpha=alpha, vertex=vertex, fixed=fixed)))
        for block in self.blocks:
            out = block(out, alpha=alpha, vertex=vertex, fixed=fixed)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.fc(out, alpha=alpha, vertex=vertex, fixed=fixed)
        return out

    def _initialize_weights(self):
        """Kaiming-normal init for all conv/linear weights; BN gets weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, (layers.Conv2d_simplex, layers.Linear_simplex)):
                # Simplex layers expose an iterable of weights; initialise each one.
                for weight in m.weight:
                    nn.init.kaiming_normal_(weight)
            elif isinstance(m, layers.Conv2d):
                # A plain Conv2d holds a single weight tensor.  It must be initialised
                # as a whole: iterating over it would yield per-output-channel slices
                # whose shape gives kaiming_normal_ the wrong fan-in.
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, layers.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
             
class ResNet_bezier(nn.Module):
    """CIFAR-style residual network assembled from `Block_bezier` units.

    The stem convolution, the block convolutions and the classifier are ``*_bezier``
    layers from `layers`; ``lambda_``, ``vertex`` and ``fixed`` given to `forward` are
    handed unchanged to each of them.

    Args:
        plan: Sequence of ``(filters, num_blocks)`` pairs, one per segment.  The first
            block of every segment after the first is built with ``downsample=True``.
        num_classes: Output size of the classifier.
        dense_classifier: Accepted but not used by this class.
        num_models: Forwarded to every ``*_bezier`` layer and block.
    """

    def __init__(self, plan, num_classes, dense_classifier, num_models=3):
        super(ResNet_bezier, self).__init__()

        # Stem: one 3x3 Bezier convolution on 3-channel input, followed by batch norm.
        width = plan[0][0]
        self.conv = layers.Conv2d_bezier(3, width, kernel_size=3, stride=1, padding=1, bias=False, num_models=num_models)
        self.bn = layers.BatchNorm2d(width)

        # Residual segments; `width` tracks the channel count flowing into the next block.
        stack = []
        for stage_idx, (filters, depth) in enumerate(plan):
            for position in range(depth):
                opens_later_stage = stage_idx > 0 and position == 0
                stack.append(Block_bezier(width, filters, opens_later_stage, num_models))
                width = filters

        # A ModuleList (not Sequential) because each block needs extra forward arguments.
        self.blocks = nn.ModuleList(stack)

        # The classifier is always a Linear_bezier, whatever `dense_classifier` says.
        self.fc = layers.Linear_bezier(plan[-1][0], num_classes, num_models=num_models)

        self._initialize_weights()

    def freeze_bn(self):
        """Disable running-stat tracking on the stem BN, set it to eval, then do the same per block."""
        stem_bn = self.bn
        stem_bn.track_running_stats = False
        stem_bn.eval()

        for blk in self.blocks:
            blk.freeze_bn()

    def forward(self, x, lambda_=0.5, vertex=0, fixed=False):
        """Stem -> blocks -> average pool over the full width -> flatten -> classifier."""
        curve_args = dict(lambda_=lambda_, vertex=vertex, fixed=fixed)
        out = F.relu(self.bn(self.conv(x, **curve_args)))
        for blk in self.blocks:
            out = blk(out, **curve_args)
        out = F.avg_pool2d(out, out.size(3))
        out = out.view(out.size(0), -1)
        return self.fc(out, **curve_args)

    def _initialize_weights(self):
        """Kaiming-normal init for each Bezier layer weight; BN gets weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, (layers.Conv2d_bezier, layers.Linear_bezier)):
                for weight in module.weight:
                    nn.init.kaiming_normal_(weight)
                continue
            if isinstance(module, layers.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                   
                
class ResNet_simplex_v2(nn.Module):
    """CIFAR-style residual network with plain convolutions and a simplex classifier.

    The stem and the residual blocks are the ordinary `layers.Conv2d` / `Block`
    modules; only the final classifier is a `layers.Linear_simplex`.  Consequently
    ``alpha``, ``vertex`` and ``fixed`` given to `forward` reach the classifier only.

    Args:
        plan: Sequence of ``(filters, num_blocks)`` pairs, one per segment.  The first
            block of every segment after the first is built with ``downsample=True``.
        num_classes: Output size of the classifier.
        dense_classifier: Accepted for signature parity with `ResNet`; not used here.
        num_models: Forwarded to the `layers.Linear_simplex` classifier.
    """

    def __init__(self, plan, num_classes, dense_classifier, num_models=3):
        super(ResNet_simplex_v2, self).__init__()

        # Initial convolution.
        current_filters = plan[0][0]
        self.conv = layers.Conv2d(3, current_filters, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = layers.BatchNorm2d(current_filters)

        # The subsequent blocks of the ResNet.
        blocks = []
        for segment_index, (filters, num_blocks) in enumerate(plan):
            for block_index in range(num_blocks):
                downsample = segment_index > 0 and block_index == 0
                blocks.append(Block(current_filters, filters, downsample))
                current_filters = filters

        self.blocks = nn.ModuleList(blocks)

        # NOTE: `dense_classifier` is ignored; the classifier is always a Linear_simplex.
        self.fc = layers.Linear_simplex(plan[-1][0], num_classes, num_models=num_models)

        self._initialize_weights()

    def freeze_bn(self):
        """Disable running-stat tracking on every batch-norm layer and set it to eval.

        The blocks here are plain `Block` modules, which define no ``freeze_bn``
        method, so their batch-norm layers are located by walking the sub-modules.
        """
        self.bn.track_running_stats = False
        self.bn.eval()

        # Assumes the block BN layers are nn.BatchNorm2d instances (the same check the
        # simplex/bezier blocks apply to their shortcut) -- TODO confirm in `layers`.
        for module in self.blocks.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.track_running_stats = False
                module.eval()

    def forward(self, x, alpha=None, vertex=0, fixed=False):
        """Run the network.

        Args:
            x: Input batch.
            alpha: Passed to the simplex classifier.  Defaults to a fresh
                ``[1.0, 1.0, 1.0]`` list on every call.
            vertex: Passed to the simplex classifier.
            fixed: Passed to the simplex classifier.
        """
        if alpha is None:
            # Build the default per call instead of sharing one mutable list object.
            alpha = [1.0, 1.0, 1.0]
        out = F.relu(self.bn(self.conv(x)))
        for block in self.blocks:
            out = block(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.fc(out, alpha=alpha, vertex=vertex, fixed=fixed)
        return out

    def _initialize_weights(self):
        """Kaiming-normal init for all conv/linear weights; BN gets weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, (layers.Conv2d_simplex, layers.Linear_simplex)):
                # Simplex layers expose an iterable of weights; initialise each one.
                for weight in m.weight:
                    nn.init.kaiming_normal_(weight)
            elif isinstance(m, layers.Conv2d):
                # A plain Conv2d holds a single weight tensor.  It must be initialised
                # as a whole: iterating over it would yield per-output-channel slices
                # whose shape gives kaiming_normal_ the wrong fan-in.
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, layers.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                
                
def _plan(D, W):
    """The naming scheme for a ResNet is 'cifar_resnet_N[_W]'.

    The ResNet is structured as an initial convolutional layer followed by three "segments"
    and a linear output layer. Each segment consists of D blocks. Each block is two
    convolutional layers surrounded by a residual connection. Each layer in the first segment
    has W filters, each layer in the second segment has 32W filters, and each layer in the
    third segment has 64W filters.

    The name of a ResNet is 'cifar_resnet_N[_W]', where W is as described above.
    N is the total number of layers in the network: 2 + 6D.
    The default value of W is 16 if it isn't provided.

    For example, ResNet-20 has 20 layers. Exclusing the first convolutional layer and the final
    linear layer, there are 18 convolutional layers in the blocks. That means there are nine
    blocks, meaning there are three blocks per segment. Hence, D = 3.
    The name of the network would be 'cifar_resnet_20' or 'cifar_resnet_20_16'.

    => for ResNet-20, D=20, W=16
    """
    if (D - 2) % 3 != 0:
        raise ValueError('Invalid ResNet depth: {}'.format(D))
    D = (D - 2) // 6
    plan = [(W, D), (2*W, D), (4*W, D)]

    return plan

def _resnet(arch, plan, num_classes, dense_classifier, pretrained):
    """Construct a `ResNet` and optionally load pretrained weights into it.

    Args:
        arch: Architecture name; selects the checkpoint file
            ``Models/pretrained/<arch>-lottery.pt``.
        plan: Segment layout passed to `ResNet`.
        num_classes: Number of output classes passed to `ResNet`.
        dense_classifier: Passed to `ResNet`.
        pretrained: If true, entries of the checkpoint overwrite the matching entries
            of the freshly initialised state dict.

    Returns:
        The constructed model.
    """
    model = ResNet(plan, num_classes, dense_classifier)
    if pretrained:
        pretrained_path = 'Models/pretrained/{}-lottery.pt'.format(arch)
        # Deserialise onto the CPU so a checkpoint saved from a GPU also loads on a
        # machine without CUDA; load_state_dict copies the values into the model.
        pretrained_dict = torch.load(pretrained_path, map_location='cpu')
        model_dict = model.state_dict()
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    return model
def _resnet_simplex(arch, plan, num_classes, dense_classifier, pretrained, num_models=3):
    """Construct a `ResNet_simplex` and optionally load pretrained weights into it.

    Args:
        arch: Architecture name; selects the checkpoint file
            ``Models/pretrained/<arch>-lottery.pt``.
        plan: Segment layout passed to `ResNet_simplex`.
        num_classes: Number of output classes passed to `ResNet_simplex`.
        dense_classifier: Passed to `ResNet_simplex`.
        pretrained: If true, entries of the checkpoint overwrite the matching entries
            of the freshly initialised state dict.
        num_models: Passed to `ResNet_simplex`.

    Returns:
        The constructed model.
    """
    model = ResNet_simplex(plan, num_classes, dense_classifier, num_models)
    if pretrained:
        pretrained_path = 'Models/pretrained/{}-lottery.pt'.format(arch)
        # Deserialise onto the CPU so a checkpoint saved from a GPU also loads on a
        # machine without CUDA; load_state_dict copies the values into the model.
        pretrained_dict = torch.load(pretrained_path, map_location='cpu')
        model_dict = model.state_dict()
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    return model
def _resnet_simplex_v2(arch, plan, num_classes, dense_classifier, pretrained, num_models=3):
    """Construct a `ResNet_simplex_v2` and optionally load pretrained weights into it.

    Args:
        arch: Architecture name; selects the checkpoint file
            ``Models/pretrained/<arch>-lottery.pt``.
        plan: Segment layout passed to `ResNet_simplex_v2`.
        num_classes: Number of output classes passed to `ResNet_simplex_v2`.
        dense_classifier: Passed to `ResNet_simplex_v2`.
        pretrained: If true, entries of the checkpoint overwrite the matching entries
            of the freshly initialised state dict.
        num_models: Passed to `ResNet_simplex_v2`.

    Returns:
        The constructed model.
    """
    model = ResNet_simplex_v2(plan, num_classes, dense_classifier, num_models)
    if pretrained:
        pretrained_path = 'Models/pretrained/{}-lottery.pt'.format(arch)
        # Deserialise onto the CPU so a checkpoint saved from a GPU also loads on a
        # machine without CUDA; load_state_dict copies the values into the model.
        pretrained_dict = torch.load(pretrained_path, map_location='cpu')
        model_dict = model.state_dict()
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    return model
def _resnet_bezier(arch, plan, num_classes, dense_classifier, pretrained, num_models=3):
    """Construct a `ResNet_bezier` and optionally load pretrained weights into it.

    Args:
        arch: Architecture name; selects the checkpoint file
            ``Models/pretrained/<arch>-lottery.pt``.
        plan: Segment layout passed to `ResNet_bezier`.
        num_classes: Number of output classes passed to `ResNet_bezier`.
        dense_classifier: Passed to `ResNet_bezier`.
        pretrained: If true, entries of the checkpoint overwrite the matching entries
            of the freshly initialised state dict.
        num_models: Passed to `ResNet_bezier`.

    Returns:
        The constructed model.
    """
    model = ResNet_bezier(plan, num_classes, dense_classifier, num_models)
    if pretrained:
        pretrained_path = 'Models/pretrained/{}-lottery.pt'.format(arch)
        # Deserialise onto the CPU so a checkpoint saved from a GPU also loads on a
        # machine without CUDA; load_state_dict copies the values into the model.
        pretrained_dict = torch.load(pretrained_path, map_location='cpu')
        model_dict = model.state_dict()
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    return model

# ResNet Models
def resnet20(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'resnet20', built from ``_plan(20, 16)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('resnet20', _plan(20, 16), num_classes, dense_classifier, pretrained)

def resnet20_simplex(input_shape, num_classes, dense_classifier=False, pretrained=False, num_models=3):
    """Return the `_resnet_simplex` model for arch 'resnet20', built from ``_plan(20, 16)``.

    `input_shape` is accepted but not used.
    """
    return _resnet_simplex('resnet20', _plan(20, 16), num_classes, dense_classifier, pretrained, num_models)

def resnet20_bezier(input_shape, num_classes, dense_classifier=False, pretrained=False, num_models=3):
    """Return the `_resnet_bezier` model for arch 'resnet20', built from ``_plan(20, 16)``.

    `input_shape` is accepted but not used.
    """
    return _resnet_bezier('resnet20', _plan(20, 16), num_classes, dense_classifier, pretrained, num_models)

def resnet20_simplex_v2(input_shape, num_classes, dense_classifier=False, pretrained=False, num_models=3):
    """Return the `_resnet_simplex_v2` model for arch 'resnet20', built from ``_plan(20, 16)``.

    `input_shape` is accepted but not used.
    """
    return _resnet_simplex_v2('resnet20', _plan(20, 16), num_classes, dense_classifier, pretrained, num_models)

def resnet32(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'resnet32', built from ``_plan(32, 16)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('resnet32', _plan(32, 16), num_classes, dense_classifier, pretrained)

def resnet32_simplex(input_shape, num_classes, dense_classifier=False, pretrained=False, num_models=3):
    """Return the `_resnet_simplex` model for arch 'resnet32', built from ``_plan(32, 16)``.

    `input_shape` is accepted but not used.
    """
    return _resnet_simplex('resnet32', _plan(32, 16), num_classes, dense_classifier, pretrained, num_models)

def resnet44(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'resnet44', built from ``_plan(44, 16)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('resnet44', _plan(44, 16), num_classes, dense_classifier, pretrained)

def resnet56(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'resnet56', built from ``_plan(56, 16)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('resnet56', _plan(56, 16), num_classes, dense_classifier, pretrained)

def resnet110(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'resnet110', built from ``_plan(110, 16)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('resnet110', _plan(110, 16), num_classes, dense_classifier, pretrained)

def resnet1202(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'resnet1202', built from ``_plan(1202, 16)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('resnet1202', _plan(1202, 16), num_classes, dense_classifier, pretrained)

# Wide ResNet Models
def wide_resnet20(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'wide_resnet20', built from ``_plan(20, 64)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('wide_resnet20', _plan(20, 64), num_classes, dense_classifier, pretrained)

def wide_resnet28x2(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'wide_resnet28x2', built from ``_plan(29, 32)``.

    Note that the depth handed to `_plan` is 29, not 28.  `input_shape` is accepted
    but not used.
    """
    return _resnet('wide_resnet28x2', _plan(29, 32), num_classes, dense_classifier, pretrained)

def wide_resnet28x4(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'wide_resnet28x4', built from ``_plan(29, 64)``.

    Note that the depth handed to `_plan` is 29, not 28.  `input_shape` is accepted
    but not used.
    """
    return _resnet('wide_resnet28x4', _plan(29, 64), num_classes, dense_classifier, pretrained)

def wide_resnet28x10(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'wide_resnet28x10', built from ``_plan(29, 160)``.

    Note that the depth handed to `_plan` is 29, not 28.  `input_shape` is accepted
    but not used.
    """
    return _resnet('wide_resnet28x10', _plan(29, 160), num_classes, dense_classifier, pretrained)

def wide_resnet32(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'wide_resnet32', built from ``_plan(32, 32)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('wide_resnet32', _plan(32, 32), num_classes, dense_classifier, pretrained)

def wide_resnet32x4(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'wide_resnet32x4', built from ``_plan(32, 64)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('wide_resnet32x4', _plan(32, 64), num_classes, dense_classifier, pretrained)

def wide_resnet32x10(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'wide_resnet32x10', built from ``_plan(32, 160)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('wide_resnet32x10', _plan(32, 160), num_classes, dense_classifier, pretrained)

def wide_resnet44(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'wide_resnet44', built from ``_plan(44, 32)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('wide_resnet44', _plan(44, 32), num_classes, dense_classifier, pretrained)

def wide_resnet56(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'wide_resnet56', built from ``_plan(56, 32)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('wide_resnet56', _plan(56, 32), num_classes, dense_classifier, pretrained)

def wide_resnet110(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'wide_resnet110', built from ``_plan(110, 32)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('wide_resnet110', _plan(110, 32), num_classes, dense_classifier, pretrained)

def wide_resnet1202(input_shape, num_classes, dense_classifier=False, pretrained=False):
    """Return the `_resnet` model for arch 'wide_resnet1202', built from ``_plan(1202, 32)``.

    `input_shape` is accepted but not used.
    """
    return _resnet('wide_resnet1202', _plan(1202, 32), num_classes, dense_classifier, pretrained)
