from .utils import LayerBuilder
from .layers import *
import torch.nn as nn
import torch.nn.functional as F


class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 conv -> 3x3 conv -> 1x1 conv, plus a skip path.

    All layers are created through ``builder`` so the concrete conv /
    batchnorm / activation implementations are decided by the caller.
    The block maps ``inplanes`` input channels to ``planes * expansion``
    output channels.
    """

    def __init__(
        self,
        builder,
        inplanes,
        planes,
        expansion,
        stride=1,
        cardinality=1,
        downsample=None,
        fused_se=True,
        last_bn_0_init=False,
        trt=False,
    ):
        """Create the three conv/batchnorm pairs of the block.

        Args:
            builder: Layer factory exposing ``conv(kernel, in_ch, out_ch, ...)``,
                ``batchnorm(channels, zero_init=...)`` and ``activation()``.
            inplanes: Number of input channels.
            planes: Width of the inner (1x1 and 3x3) convolutions.
            expansion: Output channels are ``planes * expansion``.
            stride: Stride of the middle 3x3 convolution.
            cardinality: Number of groups of the middle 3x3 convolution.
            downsample: Optional module applied to the input so the skip path
                matches the main path's shape. It is needed whenever ``stride != 1``
                or ``inplanes != planes * expansion``.
            fused_se: Stored on the module. This block itself does not use it.
            last_bn_0_init: Forwarded as ``zero_init`` to the builder for the
                final batchnorm. Presumably it zero-initialises that layer's scale
                so the block starts close to identity; confirm in LayerBuilder.
            trt: Accepted for API compatibility and unused here.
        """
        super(Bottleneck, self).__init__()
        self.conv1 = builder.conv(1, inplanes, planes)
        self.bn1 = builder.batchnorm(planes)
        self.conv2 = builder.conv(3, planes, planes, groups=cardinality, stride=stride)
        self.bn2 = builder.batchnorm(planes)
        self.conv3 = builder.conv(1, planes, planes * expansion)
        self.bn3 = builder.batchnorm(planes * expansion, zero_init=last_bn_0_init)
        # A single activation module is reused after each stage.
        self.relu = builder.activation()
        self.downsample = downsample
        self.stride = stride

        self.fused_se = fused_se

    def forward(self, x):
        """Apply the block to ``x`` and return a single output tensor.

        The return value must be a tensor (not a tuple) because blocks are
        chained inside ``nn.Sequential``.
        """
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the input when the main path changes resolution or width.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        out = self.relu(out)

        return out


class _resnet(nn.Module):
    """Configurable ResNet-style backbone followed by a linear classifier.

    Layers are created by a ``LayerBuilder`` constructed from
    ``config["layer_builder"]``. The ``config`` keys read here are
    ``layer_builder``, ``stem_width``, ``widths``, ``layers``, ``expansion``
    and ``cardinality``.
    """

    def __init__(
            self,
            config,
            num_classes=1000,
            last_bn_0_init=False,
            conv_init="fan_in",
            trt=False,
            fused_se=True
    ):
        """Build the stem, the residual stages and the classifier head.

        Args:
            config: Architecture description (see class docstring for keys).
                ``widths[i]`` and ``layers[i]`` give the inner width and the
                number of blocks of stage ``i``.
            num_classes: Output size of the final linear layer.
            last_bn_0_init: Forwarded to every block's final batchnorm.
            conv_init: Accepted but not used in this class.
            trt: Forwarded to every block.
            fused_se: Forwarded to every block.

        Raises:
            ValueError: If ``config['widths']`` and ``config['layers']`` differ
                in length.
        """
        super().__init__()

        self.builder = LayerBuilder(config["layer_builder"])
        self.last_bn_0_init = last_bn_0_init
        # Stem: 7x7 stride-2 conv on a 3-channel input, then a stride-2 max-pool.
        self.conv1 = self.builder.conv(7, 3, config['stem_width'], stride=2)
        self.bn1 = self.builder.batchnorm(config['stem_width'])
        self.relu = self.builder.activation()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        widths, depths = config['widths'], config['layers']
        if len(widths) != len(depths):
            raise ValueError(
                "config['widths'] and config['layers'] must have the same length "
                f"(got {len(widths)} and {len(depths)})"
            )

        inplanes = config['stem_width']
        self.num_layers = len(widths)
        layers = []
        for i, (width, depth) in enumerate(zip(widths, depths)):
            layer, inplanes = self._make_layer(
                Bottleneck,
                config['expansion'],
                inplanes,
                width,
                depth,
                cardinality=config['cardinality'],
                # The first stage keeps resolution (the stem already pooled);
                # every later stage halves it.
                stride=1 if i == 0 else 2,
                trt=trt,
                fused_se=fused_se,
            )
            layers.append(layer)

        self.layers = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Linear(widths[-1] * config['expansion'], num_classes)

    def stem(self, x):
        """Apply conv1 -> (optional) bn1 -> activation -> max-pool."""
        x = self.conv1(x)
        # The builder may return None for batchnorm; skip it in that case.
        if self.bn1 is not None:
            x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        return x

    def classifier(self, x):
        """Global-average-pool, flatten to ``(batch, features)`` and apply ``fc``."""
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def forward(self, batch):
        """Run the network on ``batch.x`` and return class logits.

        The input is read from the ``x`` attribute of ``batch``. This lets the
        model consume a graph-style batch object directly, without a separate
        PyTorch dataset: the sampled nodes act as the batch dimension, which also
        leaves room for contextual sampling. ``batch.x`` is presumably a 4-D
        image tensor with 3 channels (the stem conv expects 3 input channels);
        TODO confirm against the data loader.
        """
        x = batch.x
        x = self.stem(x)
        x = self.layers(x)
        x = self.classifier(x)
        return x

    def _make_layer(
        self,
        block,
        expansion,
        inplanes,
        planes,
        blocks,
        stride=1,
        cardinality=1,
        trt=False,
        fused_se=True,
    ):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and the ``downsample`` projection.
        The remaining blocks run at stride 1 on ``planes * expansion`` channels.

        Returns:
            A ``(stage, out_channels)`` tuple, where ``stage`` is an
            ``nn.Sequential`` of the blocks.
        """
        downsample = None

        # A projection is needed when the skip path would not match the main
        # path in resolution or channel count.
        if stride != 1 or inplanes != planes * expansion:
            dconv = self.builder.conv(1, inplanes, planes * expansion, stride=stride)
            dbn = self.builder.batchnorm(planes * expansion)
            if dbn is not None:
                downsample = nn.Sequential(dconv, dbn)
            else:
                downsample = dconv

        layers = []
        for i in range(blocks):
            layers.append(
                block(
                    self.builder,
                    inplanes,
                    planes,
                    expansion,
                    stride=stride if i == 0 else 1,
                    cardinality=cardinality,
                    downsample=downsample if i == 0 else None,
                    fused_se=fused_se,
                    last_bn_0_init=self.last_bn_0_init,
                    trt=trt,
                )
            )
            inplanes = planes * expansion

        return nn.Sequential(*layers), inplanes


class ResNet50(_resnet):
    """``_resnet`` built from ``config``, optionally initialised from pretrained weights.

    It is called as ``ResNet50(config, num_classes)`` and the result is an
    ``nn.Module`` instance.
    """

    def __init__(self, config, num_classes=1000, *, weights=None, progress=True,
                 strict=True, **kwargs):
        """Build the network and optionally load a pretrained state dict.

        Args:
            config: Architecture description forwarded to ``_resnet``.
            num_classes: Output size of the final linear layer.
            weights: Optional object exposing ``get_state_dict(progress=...)``,
                for example a torchvision ``ResNet50_Weights`` member such as
                ``ResNet50_Weights.IMAGENET1K_V2``. ``None`` keeps the freshly
                constructed parameters.
            progress: Forwarded to ``weights.get_state_dict``.
            strict: Forwarded to ``load_state_dict``.
            **kwargs: Extra keyword arguments for ``_resnet`` (e.g.
                ``last_bn_0_init``, ``fused_se``).
        """
        super().__init__(config, num_classes=num_classes, **kwargs)
        if weights is not None:
            state_dict = weights.get_state_dict(progress=progress)
            # NOTE(review): this model stores its stages under ``layers.<i>``,
            # while torchvision checkpoints are believed to use ``layer1``..``layer4``
            # and a 1000-way ``fc``. Keys (and ``fc`` for other ``num_classes``)
            # likely need remapping before a strict load succeeds; confirm.
            self.load_state_dict(state_dict, strict=strict)

