import torch
import torch.nn as nn
import copy
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152


class ResNetEncoder(nn.Module):
    """Torchvision ResNet backbone used as a feature encoder.

    ``forward`` runs the ResNet stem and ``layer1``..``layer4``, then applies
    the backbone's ``avgpool`` and flattens the result. The ResNet ``fc``
    classifier is never called; its ``in_features`` is exposed as
    ``out_dim``, the width of the returned feature.

    With ``hierarchies=True`` three extra heads are built. Each head is an
    independent (deep-copied) replica of the remaining ResNet stages,
    followed by pooling. ``forward`` then returns one embedding per
    intermediate stage in addition to the main feature.

    Args:
        cfg: Config object. Reads ``cfg.resnet_type.name`` (one of
            ``"resnet18"``, ``"resnet34"``, ``"resnet50"``, ``"resnet101"``,
            ``"resnet152"``) and ``cfg.pretrain`` (when truthy, build the
            backbone with ``weights="DEFAULT"``).
        in_size: Input resolution. ``32`` swaps in a small-input stem
            (3x3, stride-1 ``conv1`` and no max-pool). Any other value keeps
            torchvision's default stem.
        hierarchies: When true, build ``hierarchy1``..``hierarchy3`` and make
            ``forward`` return a list of four features.

    Raises:
        KeyError: If ``cfg.resnet_type.name`` is not one of the supported
            ResNet variants.
    """

    def __init__(self, cfg, in_size=224, hierarchies=False):
        super().__init__()
        # Explicit registry rather than a ``globals()`` lookup: globals()
        # would happily resolve unrelated module-level names and gives no
        # hint of the valid choices on a typo. Built here (not at class
        # level) so the torchvision names are only needed at construction.
        builders = {
            "resnet18": resnet18,
            "resnet34": resnet34,
            "resnet50": resnet50,
            "resnet101": resnet101,
            "resnet152": resnet152,
        }
        name = cfg.resnet_type.name
        try:
            build = builders[name]
        except KeyError:
            # Keep the KeyError type for backward compatibility, but say why.
            raise KeyError(
                f"Unsupported resnet_type {name!r}; expected one of {sorted(builders)}"
            ) from None
        self.resnet = build(weights="DEFAULT") if cfg.pretrain else build()

        # Width of the feature fed to the (unused) classifier, which is the
        # width of the pooled feature returned by ``forward``.
        self.out_dim = self.resnet.fc.in_features

        if in_size == 32:
            # Small-input stem: 3x3 stride-1 conv and no max-pool. NOTE: this
            # conv1 is freshly initialised, so any pretrained conv1 weights
            # loaded above are discarded.
            old_conv = self.resnet.conv1
            self.resnet.conv1 = nn.Conv2d(
                old_conv.in_channels,
                old_conv.out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            )
            self.resnet.maxpool = nn.Identity()

        self.hierarchies = hierarchies
        if hierarchies:
            # Deep copies happen after any stem change above, mirroring the
            # backbone's current (possibly pretrained) stage weights.
            r = self.resnet
            self.hierarchy1 = self._pooled_head(r.layer2, r.layer3, r.layer4)
            self.hierarchy2 = self._pooled_head(r.layer3, r.layer4)
            self.hierarchy3 = self._pooled_head(r.layer4)

    def _pooled_head(self, *stages):
        """Return independent copies of ``stages`` followed by pooling and flatten."""
        # Layout (copied stages, shared avgpool, Flatten) matches the original
        # hand-written Sequential, so state_dict keys are unchanged.
        return nn.Sequential(
            *(copy.deepcopy(stage) for stage in stages),
            self.resnet.avgpool,
            nn.Flatten(),
        )

    def forward(self, x):
        """Encode a batch of inputs.

        Args:
            x: Input batch. Presumably ``(batch, C, H, W)`` with ``C`` equal to
                ``self.resnet.conv1.in_channels`` — TODO confirm with callers.

        Returns:
            The pooled ``layer4`` feature flattened from dim 1. If
            ``self.hierarchies`` is true, the list ``[h1, h2, h3, out]``
            instead. Here ``h1``, ``h2`` and ``h3`` are the outputs of
            ``hierarchy1``/``hierarchy2``/``hierarchy3`` applied to the
            ``layer1``/``layer2``/``layer3`` activations, and ``out`` is the
            main feature.
        """
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        out1 = self.resnet.layer1(x)
        out2 = self.resnet.layer2(out1)
        out3 = self.resnet.layer3(out2)
        out4 = self.resnet.layer4(out3)

        out = torch.flatten(self.resnet.avgpool(out4), 1)
        if not self.hierarchies:
            return out
        return [
            self.hierarchy1(out1),
            self.hierarchy2(out2),
            self.hierarchy3(out3),
            out,
        ]
