#include "resnet.h"

namespace vision {
namespace models {
namespace _resnetimpl {
// Creates a bias-free 2-D convolution module with kernel size 3 and padding 1.
// `in` and `out` are passed as the first two Conv2dOptions arguments (channel
// counts); `stride` and `groups` are forwarded unchanged to the options.
torch::nn::Conv2d conv3x3(
    int64_t in,
    int64_t out,
    int64_t stride,
    int64_t groups) {
  auto options = torch::nn::Conv2dOptions(in, out, /*kernel_size=*/3)
                     .stride(stride)
                     .padding(1)
                     .groups(groups)
                     .bias(false);
  return torch::nn::Conv2d(options);
}

// Creates a bias-free 2-D convolution module with kernel size 1.
// `in` and `out` are passed as the first two Conv2dOptions arguments (channel
// counts); `stride` is forwarded unchanged to the options.
torch::nn::Conv2d conv1x1(int64_t in, int64_t out, int64_t stride) {
  auto options = torch::nn::Conv2dOptions(in, out, /*kernel_size=*/1)
                     .stride(stride)
                     .bias(false);
  return torch::nn::Conv2d(options);
}

// Definitions of the per-block-type `expansion` static members.
// Bottleneck multiplies `planes` by its value (4) to size its last convolution
// and batch norm. BasicBlock's value (1) is not read by the code in this file;
// its convolutions output exactly `planes` channels.
int BasicBlock::expansion{1};
int Bottleneck::expansion{4};

// Builds a residual block made of two 3x3 convolutions, each followed by a
// batch norm.
//
//   inplanes   - channel count fed to the first convolution
//   planes     - channel count produced by both convolutions
//   stride     - stride applied by the first convolution (also stored)
//   downsample - optional module for the shortcut branch; stored, and
//                registered as a sub-module only when it is non-empty
//   groups, base_width - must be 1 and 64 respectively; anything else fails
//                the TORCH_CHECK below
BasicBlock::BasicBlock(
    int64_t inplanes,
    int64_t planes,
    int64_t stride,
    const torch::nn::Sequential& downsample,
    int64_t groups,
    int64_t base_width)
    : stride(stride), downsample(downsample) {
  TORCH_CHECK(
      groups == 1 && base_width == 64,
      "BasicBlock only supports groups=1 and base_width=64");

  // Each sub-module is created and registered in one step. The order
  // (conv1, conv2, bn1, bn2) is kept so that both construction order and
  // registration order stay the same.
  // When stride != 1, conv1 and the downsample module both reduce the input.
  conv1 = register_module("conv1", conv3x3(inplanes, planes, stride));
  conv2 = register_module("conv2", conv3x3(planes, planes));

  bn1 = register_module("bn1", torch::nn::BatchNorm2d(planes));
  bn2 = register_module("bn2", torch::nn::BatchNorm2d(planes));

  // The shortcut module is optional; an empty holder is left unregistered.
  const bool has_downsample = !downsample.is_empty();
  if (has_downsample) {
    register_module("downsample", this->downsample);
  }
}

// Builds a residual block made of a 1x1, a 3x3 and a 1x1 convolution, each
// followed by a batch norm.
//
//   inplanes   - channel count fed to the first convolution
//   planes     - base channel count; the block outputs planes * expansion
//   stride     - stride applied by the 3x3 convolution (also stored)
//   downsample - optional module for the shortcut branch; stored, and
//                registered as a sub-module only when it is non-empty
//   groups     - group count of the 3x3 convolution; also scales the width
//   base_width - scales the inner width relative to 64
Bottleneck::Bottleneck(
    int64_t inplanes,
    int64_t planes,
    int64_t stride,
    const torch::nn::Sequential& downsample,
    int64_t groups,
    int64_t base_width)
    : stride(stride), downsample(downsample) {
  // Inner channel count: planes scaled by base_width / 64 (truncated to an
  // integer), then multiplied by the number of groups.
  const int64_t width =
      static_cast<int64_t>(planes * (base_width / 64.)) * groups;
  const int64_t out_planes = planes * expansion;

  // Each sub-module is created and registered in one step. The order
  // (conv1..conv3, then bn1..bn3) is kept so that both construction order and
  // registration order stay the same.
  // When stride != 1, conv2 and the downsample module both reduce the input.
  conv1 = register_module("conv1", conv1x1(inplanes, width));
  conv2 = register_module("conv2", conv3x3(width, width, stride, groups));
  conv3 = register_module("conv3", conv1x1(width, out_planes));

  bn1 = register_module("bn1", torch::nn::BatchNorm2d(width));
  bn2 = register_module("bn2", torch::nn::BatchNorm2d(width));
  bn3 = register_module("bn3", torch::nn::BatchNorm2d(out_planes));

  // The shortcut module is optional; an empty holder is left unregistered.
  const bool has_downsample = !downsample.is_empty();
  if (has_downsample) {
    register_module("downsample", this->downsample);
  }
}

// Runs the block on X: conv1 -> bn1 -> relu -> conv2 -> bn2 -> relu ->
// conv3 -> bn3, then adds the shortcut and applies a final in-place relu.
// The shortcut is X itself, or downsample(X) when a downsample module is set.
torch::Tensor Bottleneck::forward(torch::Tensor X) {
  auto out = bn1->forward(conv1->forward(X)).relu_();
  out = bn2->forward(conv2->forward(out)).relu_();
  out = bn3->forward(conv3->forward(out));

  // The shortcut is evaluated only after the main path has been computed.
  out += downsample.is_empty() ? X : downsample->forward(X);
  return out.relu_();
}

// Runs the block on x: conv1 -> bn1 -> relu -> conv2 -> bn2, then adds the
// shortcut and applies a final in-place relu.
// The shortcut is x itself, or downsample(x) when a downsample module is set.
torch::Tensor BasicBlock::forward(torch::Tensor x) {
  auto out = bn1->forward(conv1->forward(x)).relu_();
  out = bn2->forward(conv2->forward(out));

  // The shortcut is evaluated only after the main path has been computed.
  out += downsample.is_empty() ? x : downsample->forward(x);
  return out.relu_();
}
} // namespace _resnetimpl

// Constructs the ResNet-18 variant by forwarding the list {2, 2, 2, 2},
// num_classes and zero_init_residual to the ResNetImpl base constructor.
// NOTE(review): the list presumably gives the number of residual blocks per
// stage — confirm against the ResNetImpl declaration.
ResNet18Impl::ResNet18Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({2, 2, 2, 2}, num_classes, zero_init_residual) {}

// Constructs the ResNet-34 variant by forwarding the list {3, 4, 6, 3},
// num_classes and zero_init_residual to the ResNetImpl base constructor.
// NOTE(review): the list presumably gives the number of residual blocks per
// stage — confirm against the ResNetImpl declaration.
ResNet34Impl::ResNet34Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual) {}

// Constructs the ResNet-50 variant by forwarding the list {3, 4, 6, 3},
// num_classes and zero_init_residual to the ResNetImpl base constructor.
// The list is the same one the ResNet34Impl constructor passes; whatever
// distinguishes the two models must come from the class declarations, which
// are not visible in this file.
// NOTE(review): the list presumably gives the number of residual blocks per
// stage — confirm against the ResNetImpl declaration.
ResNet50Impl::ResNet50Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual) {}

// Constructs the ResNet-101 variant by forwarding the list {3, 4, 23, 3},
// num_classes and zero_init_residual to the ResNetImpl base constructor.
// NOTE(review): the list presumably gives the number of residual blocks per
// stage — confirm against the ResNetImpl declaration.
ResNet101Impl::ResNet101Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual) {}

// Constructs the ResNet-152 variant by forwarding the list {3, 8, 36, 3},
// num_classes and zero_init_residual to the ResNetImpl base constructor.
// NOTE(review): the list presumably gives the number of residual blocks per
// stage — confirm against the ResNetImpl declaration.
ResNet152Impl::ResNet152Impl(int64_t num_classes, bool zero_init_residual)
    : ResNetImpl({3, 8, 36, 3}, num_classes, zero_init_residual) {}

// Constructs the ResNeXt-50 (32x4d) variant by forwarding the list
// {3, 4, 6, 3}, num_classes, zero_init_residual and the two extra arguments
// 32 and 4 to the ResNetImpl base constructor.
// NOTE(review): 32 and 4 are presumably the group count and the width per
// group, as the "32x4d" in the class name suggests — confirm against the
// ResNetImpl declaration.
ResNext50_32x4dImpl::ResNext50_32x4dImpl(
    int64_t num_classes,
    bool zero_init_residual)
    : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual, 32, 4) {}

// Constructs the ResNeXt-101 (32x8d) variant by forwarding the list
// {3, 4, 23, 3}, num_classes, zero_init_residual and the two extra arguments
// 32 and 8 to the ResNetImpl base constructor.
// NOTE(review): 32 and 8 are presumably the group count and the width per
// group, as the "32x8d" in the class name suggests — confirm against the
// ResNetImpl declaration.
ResNext101_32x8dImpl::ResNext101_32x8dImpl(
    int64_t num_classes,
    bool zero_init_residual)
    : ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual, 32, 8) {}

// Constructs the Wide ResNet-50-2 variant by forwarding the list
// {3, 4, 6, 3}, num_classes, zero_init_residual and the two extra arguments
// 1 and 64 * 2 to the ResNetImpl base constructor.
// NOTE(review): 1 and 64 * 2 are presumably the group count and a doubled
// base width of 64 — confirm against the ResNetImpl declaration.
WideResNet50_2Impl::WideResNet50_2Impl(
    int64_t num_classes,
    bool zero_init_residual)
    : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual, 1, 64 * 2) {}

// Constructs the Wide ResNet-101-2 variant by forwarding the list
// {3, 4, 23, 3}, num_classes, zero_init_residual and the two extra arguments
// 1 and 64 * 2 to the ResNetImpl base constructor.
// NOTE(review): 1 and 64 * 2 are presumably the group count and a doubled
// base width of 64 — confirm against the ResNetImpl declaration.
WideResNet101_2Impl::WideResNet101_2Impl(
    int64_t num_classes,
    bool zero_init_residual)
    : ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual, 1, 64 * 2) {}

} // namespace models
} // namespace vision
