import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by BatchNorm.

    The block input is added back to the output of the second BatchNorm
    before the final ReLU. When the spatial stride or the channel count
    changes, the skip path is a 1x1 convolution + BatchNorm projection;
    otherwise it is an empty ``nn.Sequential`` (identity).

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection).
    """

    # Output channels are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path. Submodules are created in a fixed order so state_dict
        # keys and seeded initialisation stay stable.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip path: project only if the shapes would not match otherwise.
        shapes_match = stride == 1 and in_planes == out_planes
        if shapes_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        hidden += self.shortcut(x)
        return F.relu(hidden)


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions with BatchNorm.

    The last 1x1 convolution widens the features to ``expansion * planes``
    channels. The block input is added to the result before the final ReLU,
    going through a 1x1 convolution + BatchNorm projection whenever the
    stride or channel count differs, and through an empty ``nn.Sequential``
    (identity) otherwise.

    Args:
        in_planes: Number of input channels.
        planes: Channel width of the first two convolutions.
        stride: Stride of the 3x3 convolution (and of the projection).
    """

    # Output channels are ``expansion * planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path. Submodules are created in a fixed order so state_dict
        # keys and seeded initialisation stay stable.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Skip path: project only if the shapes would not match otherwise.
        shapes_match = stride == 1 and in_planes == out_planes
        if shapes_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)


class ResNet(nn.Module):
    """ResNet backbone with a 3x3 stem, four residual stages and a linear head.

    Args:
        block: Block factory called as ``block(in_planes, planes, stride)``;
            it must expose an integer ``expansion`` attribute.
        n_blocks: Sequence of four block counts, one per stage.
        num_classes: Size of the output of the final linear layer.

    Note:
        The head applies a fixed 4x4 average pool and then flattens into a
        linear layer sized for ``512 * block.expansion`` features, so the
        last stage must produce a 4x4 map (presumably 32x32 inputs --
        confirm against the data pipeline).
    """

    def __init__(self, block, n_blocks, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer updates it after every block.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        first, second, third, fourth = n_blocks[0], n_blocks[1], n_blocks[2], n_blocks[3]
        self.layer1 = self._make_layer(block, 64, first, stride=1)
        self.layer2 = self._make_layer(block, 128, second, stride=2)
        self.layer3 = self._make_layer(block, 256, third, stride=2)
        self.layer4 = self._make_layer(block, 512, fourth, stride=2)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        stage = []
        # Every block after the first keeps the resolution (stride 1).
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes this block's (expanded) output width.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map an image batch to ``(batch, num_classes)`` logits."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)


class DropoutBasicBlock(BasicBlock):
    """``BasicBlock`` with dropout after each ReLU.

    Dropout is called with ``training=True`` unconditionally, so it stays
    active regardless of the module's train/eval mode.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        dropout: Drop probability passed to ``F.dropout``.
        stride: Stride of the first convolution (and of the projection).
    """

    def __init__(self, in_planes, planes, dropout, stride=1):
        super().__init__(in_planes, planes, stride=stride)
        self.dropout = dropout

    def _drop(self, tensor):
        """Apply always-on, out-of-place dropout with rate ``self.dropout``."""
        return F.dropout(tensor, p=self.dropout, training=True, inplace=False)

    def forward(self, x):
        """Run the residual block, dropping activations after both ReLUs."""
        hidden = self._drop(F.relu(self.bn1(self.conv1(x))))
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return self._drop(F.relu(hidden))


class DropoutBottleneck(Bottleneck):
    """``Bottleneck`` with dropout after each of its three ReLUs.

    Dropout is called with ``training=True`` unconditionally, so it stays
    active regardless of the module's train/eval mode.

    Args:
        in_planes: Number of input channels.
        planes: Channel width of the first two convolutions.
        dropout: Drop probability passed to ``F.dropout``.
        stride: Stride of the 3x3 convolution (and of the projection).
    """

    def __init__(self, in_planes, planes, dropout, stride=1):
        super().__init__(in_planes, planes, stride=stride)
        self.dropout = dropout

    def _drop(self, tensor):
        """Apply always-on, out-of-place dropout with rate ``self.dropout``."""
        return F.dropout(tensor, p=self.dropout, training=True, inplace=False)

    def forward(self, x):
        """Run the bottleneck block, dropping activations after every ReLU."""
        hidden = self._drop(F.relu(self.bn1(self.conv1(x))))
        hidden = self._drop(F.relu(self.bn2(self.conv2(hidden))))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return self._drop(F.relu(hidden))


class DropoutResNet(nn.Module):
    """ResNet whose stem and blocks apply always-on dropout.

    Dropout in ``forward`` is called with ``training=True`` unconditionally,
    so it stays active regardless of the module's train/eval mode.

    Args:
        block: Block factory called as
            ``block(in_planes, planes, dropout, stride)``; it must expose an
            integer ``expansion`` attribute.
        n_blocks: Sequence of four block counts, one per stage.
        dropout: Drop probability used after the stem and handed to every
            block.
        num_classes: Size of the output of the final linear layer.

    Note:
        The head applies a fixed 4x4 average pool and then flattens into a
        linear layer sized for ``512 * block.expansion`` features, so the
        last stage must produce a 4x4 map (presumably 32x32 inputs --
        confirm against the data pipeline).
    """

    def __init__(self, block, n_blocks, dropout, num_classes=10):
        super(DropoutResNet, self).__init__()
        # forward() reads self.dropout for the stem dropout, so the rate
        # must be stored on the instance.
        self.dropout = dropout
        # Running channel count; _make_layer updates it after every block.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, n_blocks[0], dropout, stride=1)
        self.layer2 = self._make_layer(block, 128, n_blocks[1], dropout, stride=2)
        self.layer3 = self._make_layer(block, 256, n_blocks[2], dropout, stride=2)
        self.layer4 = self._make_layer(block, 512, n_blocks[3], dropout, stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, dropout, stride):
        """Build one stage; only its first block uses ``stride``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout, stride))
            # The next block consumes this block's (expanded) output width.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch to ``(batch, num_classes)`` logits."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.dropout(out, p=self.dropout, training=True, inplace=False)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


class RadialConv2d(nn.Conv2d):
    """Conv2d whose weights (and bias, if any) are perturbed on every forward.

    Each forward pass draws a Gaussian tensor, normalises it to unit norm,
    scales it by a single scalar ``r ~ N(0, 1)`` and by ``softplus(rho)``,
    and adds the result to the stored mean parameter.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        w_rho_init: Initial value of every element of ``weight_rho``.
        b_rho_init: Initial value of every element of ``bias_rho`` (only
            created when the layer has a bias).
        **kwargs: Forwarded to ``nn.Conv2d`` (stride, padding, bias, ...).
    """

    def __init__(self, in_channels, out_channels, kernel_size, w_rho_init=-5, b_rho_init=-5, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        self.weight_rho = nn.Parameter(torch.full_like(self.weight, w_rho_init), requires_grad=True)
        # nn.Conv2d stores ``None`` when bias=False. Test identity, because
        # the truth value of a multi-element tensor is ambiguous and raises.
        if self.bias is not None:
            self.bias_rho = nn.Parameter(torch.full_like(self.bias, b_rho_init), requires_grad=True)

    def forward(self, x):
        """Convolve ``x`` with one freshly sampled set of parameters."""
        # Draw order (weight direction, then radius) is kept fixed so the
        # bias-free path produces the same samples for a given seed.
        eps = torch.randn_like(self.weight)
        r = torch.randn(1, device=self.weight.device)
        weight = self.weight + F.softplus(self.weight_rho) * (eps / eps.norm() * r)
        if self.bias is not None:
            # The bias needs a direction of its own shape; the weight-shaped
            # ``eps`` cannot be broadcast against it. The radius is shared.
            bias_eps = torch.randn_like(self.bias)
            bias = self.bias + F.softplus(self.bias_rho) * (bias_eps / bias_eps.norm() * r)
        else:
            bias = None
        return F.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)


class RadialBasicBlock(nn.Module):
    """Basic residual block built from ``RadialConv2d`` layers.

    Two 3x3 radial convolutions, each followed by BatchNorm. The block input
    is added back before the final ReLU, through a 1x1 radial convolution +
    BatchNorm projection when the stride or channel count changes and
    through an empty ``nn.Sequential`` (identity) otherwise.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection).
    """

    # Output channels are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path; creation order is fixed so state_dict keys and seeded
        # initialisation stay stable.
        self.conv1 = RadialConv2d(in_planes, planes, kernel_size=3,
                                  stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = RadialConv2d(planes, planes, kernel_size=3,
                                  stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip path: project only if the shapes would not match otherwise.
        shapes_match = stride == 1 and in_planes == out_planes
        if shapes_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                RadialConv2d(in_planes, out_planes, kernel_size=1,
                             stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        # The main path runs before the shortcut so the stochastic layers
        # draw their noise in a fixed order.
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        hidden += self.shortcut(x)
        return F.relu(hidden)


class RadialBottleneck(nn.Module):
    """Bottleneck residual block built from ``RadialConv2d`` layers.

    1x1 -> 3x3 -> 1x1 radial convolutions with BatchNorm; the last one
    widens the features to ``expansion * planes`` channels. The block input
    is added back before the final ReLU, through a 1x1 radial convolution +
    BatchNorm projection when the stride or channel count changes and
    through an empty ``nn.Sequential`` (identity) otherwise.

    Args:
        in_planes: Number of input channels.
        planes: Channel width of the first two convolutions.
        stride: Stride of the 3x3 convolution (and of the projection).
    """

    # Output channels are ``expansion * planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path; creation order is fixed so state_dict keys and seeded
        # initialisation stay stable.
        self.conv1 = RadialConv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = RadialConv2d(planes, planes, kernel_size=3,
                                  stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = RadialConv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Skip path: project only if the shapes would not match otherwise.
        shapes_match = stride == 1 and in_planes == out_planes
        if shapes_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                RadialConv2d(in_planes, out_planes, kernel_size=1,
                             stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        # The main path runs before the shortcut so the stochastic layers
        # draw their noise in a fixed order.
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)


class RadialResNet(nn.Module):
    """ResNet with a ``RadialConv2d`` stem, four stages and a linear head.

    Args:
        block: Block factory called as ``block(in_planes, planes, stride)``;
            it must expose an integer ``expansion`` attribute.
        n_blocks: Sequence of four block counts, one per stage.
        num_classes: Size of the output of the final linear layer.

    Note:
        The head applies a fixed 4x4 average pool and then flattens into a
        linear layer sized for ``512 * block.expansion`` features, so the
        last stage must produce a 4x4 map (presumably 32x32 inputs --
        confirm against the data pipeline). The linear head is a plain
        ``nn.Linear``.
    """

    def __init__(self, block, n_blocks, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer updates it after every block.
        self.in_planes = 64

        self.conv1 = RadialConv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        first, second, third, fourth = n_blocks[0], n_blocks[1], n_blocks[2], n_blocks[3]
        self.layer1 = self._make_layer(block, 64, first, stride=1)
        self.layer2 = self._make_layer(block, 128, second, stride=2)
        self.layer3 = self._make_layer(block, 256, third, stride=2)
        self.layer4 = self._make_layer(block, 512, fourth, stride=2)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        stage = []
        # Every block after the first keeps the resolution (stride 1).
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes this block's (expanded) output width.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map an image batch to ``(batch, num_classes)`` logits."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)


class Rank1Conv2d(nn.Conv2d):
    """Conv2d with per-model stochastic input and output channel multipliers.

    The convolution weight is shared. Each of the ``n_models`` members owns
    a Gaussian over an input-channel multiplier ``alpha`` and an
    output-channel multiplier ``gamma``, with standard deviation
    ``softplus(rho)``. The layer computes ``conv(x * alpha) * gamma`` with a
    fresh multiplier sample for every example.

    The batch is split into ``n_models`` contiguous, equally sized groups:
    the first ``batch // n_models`` examples use member 0, the next group
    member 1, and so on.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        n_models: Number of ensemble members.
        kernel_size: Convolution kernel size.
        alpha_mu_init: Centre of the random initialisation of ``alpha_mu``.
        alpha_rho_init: Initial value of ``alpha_rho``; ``softplus`` of it
            also scales the initial noise on ``alpha_mu``.
        gamma_mu_init: Centre of the random initialisation of ``gamma_mu``.
        gamma_rho_init: Initial value of ``gamma_rho``; ``softplus`` of it
            also scales the initial noise on ``gamma_mu``.
        **kwargs: Forwarded to ``nn.Conv2d`` (stride, padding, bias, ...).
    """

    def __init__(self, in_channels, out_channels, n_models, kernel_size,
                 alpha_mu_init=1., alpha_rho_init=-3., gamma_mu_init=1., gamma_rho_init=-3., **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        self.n_models = n_models
        # Means start as the init value plus Gaussian noise whose scale
        # matches the initial posterior standard deviation.
        alpha_mu = alpha_mu_init + F.softplus(torch.tensor(alpha_rho_init)) * torch.randn([n_models, in_channels])
        self.alpha_mu = nn.Parameter(alpha_mu, requires_grad=True)
        self.alpha_rho = nn.Parameter(torch.full([n_models, in_channels], alpha_rho_init), requires_grad=True)
        gamma_mu = gamma_mu_init + F.softplus(torch.tensor(gamma_rho_init)) * torch.randn([n_models, out_channels])
        self.gamma_mu = nn.Parameter(gamma_mu, requires_grad=True)
        self.gamma_rho = nn.Parameter(torch.full([n_models, out_channels], gamma_rho_init), requires_grad=True)

    def forward(self, x):
        """Apply the shared convolution with sampled per-example multipliers.

        Raises:
            ValueError: If the batch size is not a multiple of ``n_models``.
        """
        batch_size = x.shape[0]
        # Without this check the multipliers get a leading size different
        # from the batch, which either fails with an opaque broadcasting
        # error or silently broadcasts to a wrong-sized (even empty) output.
        if batch_size % self.n_models != 0:
            raise ValueError(
                "batch size ({}) must be a multiple of n_models ({})".format(
                    batch_size, self.n_models))
        examples_per_model = batch_size // self.n_models

        # rsample gives (examples_per_model, n_models, channels).
        alpha = Normal(self.alpha_mu, F.softplus(self.alpha_rho)).rsample([examples_per_model])
        gamma = Normal(self.gamma_mu, F.softplus(self.gamma_rho)).rsample([examples_per_model])
        # Reorder to model-major and flatten to (batch, channels, 1, 1) so the
        # multipliers broadcast over the spatial dimensions.
        alpha = alpha.transpose(0, 1).reshape(-1, self.in_channels).unsqueeze(-1).unsqueeze(-1)
        gamma = gamma.transpose(0, 1).reshape(-1, self.out_channels).unsqueeze(-1).unsqueeze(-1)
        output = F.conv2d(x * alpha, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return output * gamma


class Rank1BasicBlock(nn.Module):
    """Basic residual block built from ``Rank1Conv2d`` layers.

    Two 3x3 rank-1 convolutions, each followed by BatchNorm. The block input
    is added back before the final ReLU, through a 1x1 rank-1 convolution +
    BatchNorm projection when the stride or channel count changes and
    through an empty ``nn.Sequential`` (identity) otherwise.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        n_models: Ensemble size handed to every ``Rank1Conv2d``.
        stride: Stride of the first convolution (and of the projection).
    """

    # Output channels are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, n_models, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path; creation order is fixed so state_dict keys and seeded
        # initialisation stay stable.
        self.conv1 = Rank1Conv2d(in_planes, planes, n_models, kernel_size=3,
                                 stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = Rank1Conv2d(planes, planes, n_models, kernel_size=3,
                                 stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip path: project only if the shapes would not match otherwise.
        shapes_match = stride == 1 and in_planes == out_planes
        if shapes_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                Rank1Conv2d(in_planes, out_planes, n_models, kernel_size=1,
                            stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        # The main path runs before the shortcut so the stochastic layers
        # draw their samples in a fixed order.
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        hidden += self.shortcut(x)
        return F.relu(hidden)


class Rank1Bottleneck(nn.Module):
    """Bottleneck residual block built from ``Rank1Conv2d`` layers.

    1x1 -> 3x3 -> 1x1 rank-1 convolutions with BatchNorm; the last one
    widens the features to ``expansion * planes`` channels. The block input
    is added back before the final ReLU, through a 1x1 rank-1 convolution +
    BatchNorm projection when the stride or channel count changes and
    through an empty ``nn.Sequential`` (identity) otherwise.

    Args:
        in_planes: Number of input channels.
        planes: Channel width of the first two convolutions.
        n_models: Ensemble size handed to every ``Rank1Conv2d``.
        stride: Stride of the 3x3 convolution (and of the projection).
    """

    # Output channels are ``expansion * planes``.
    expansion = 4

    def __init__(self, in_planes, planes, n_models, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path; creation order is fixed so state_dict keys and seeded
        # initialisation stay stable.
        self.conv1 = Rank1Conv2d(in_planes, planes, n_models, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = Rank1Conv2d(planes, planes, n_models, kernel_size=3,
                                 stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = Rank1Conv2d(planes, out_planes, n_models, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Skip path: project only if the shapes would not match otherwise.
        shapes_match = stride == 1 and in_planes == out_planes
        if shapes_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                Rank1Conv2d(in_planes, out_planes, n_models, kernel_size=1,
                            stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        # The main path runs before the shortcut so the stochastic layers
        # draw their samples in a fixed order.
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)


class Rank1ResNet(nn.Module):
    """ResNet with a ``Rank1Conv2d`` stem, four stages and a linear head.

    Args:
        block: Block factory called as
            ``block(in_planes, planes, n_models, stride)``; it must expose an
            integer ``expansion`` attribute.
        n_blocks: Sequence of four block counts, one per stage.
        n_models: Ensemble size handed to the stem and to every block.
        num_classes: Size of the output of the final linear layer.

    Note:
        The head applies a fixed 4x4 average pool and then flattens into a
        linear layer sized for ``512 * block.expansion`` features, so the
        last stage must produce a 4x4 map (presumably 32x32 inputs --
        confirm against the data pipeline). The linear head is a plain
        ``nn.Linear`` shared by all members.
    """

    def __init__(self, block, n_blocks, n_models, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer updates it after every block.
        self.in_planes = 64

        self.conv1 = Rank1Conv2d(3, 64, n_models, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        first, second, third, fourth = n_blocks[0], n_blocks[1], n_blocks[2], n_blocks[3]
        self.layer1 = self._make_layer(block, 64, first, n_models, stride=1)
        self.layer2 = self._make_layer(block, 128, second, n_models, stride=2)
        self.layer3 = self._make_layer(block, 256, third, n_models, stride=2)
        self.layer4 = self._make_layer(block, 512, fourth, n_models, stride=2)

        self.linear = nn.Linear(512 * block.expansion, num_classes)
        self.n_models = n_models

    def _make_layer(self, block, planes, num_blocks, n_models, stride):
        """Build one stage; only its first block uses ``stride``."""
        stage = []
        # Every block after the first keeps the resolution (stride 1).
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, n_models, block_stride))
            # The next block consumes this block's (expanded) output width.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map an image batch to ``(batch, num_classes)`` logits."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)
