import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchvision import transforms


# Relative directory path for pretrained assets (not referenced within this chunk).
pretrain_path = "algorithms/pretrain"


def backbone(hparams):
    """Build the feature backbone that matches ``hparams['dataset']``.

    Args:
        hparams: Hyper-parameter mapping. ``hparams['dataset']`` selects the
            backbone; the selected class reads further keys itself
            (``'pretrained'``, ``'feature_dim'`` and, for the 3D backbone,
            ``'n_channel'``).

    Returns:
        ``Image2DBacknone`` for 'mimiccxr' / 'areds', or
        ``Image3DBackbone`` for 'adni' / 'ukb-mi'.

    Raises:
        KeyError: if ``hparams`` has no ``'dataset'`` entry.
        NotImplementedError: if the dataset has no backbone mapping.
    """
    dataset = hparams['dataset']
    if dataset in ('mimiccxr', 'areds'):
        return Image2DBacknone(hparams)
    if dataset in ('adni', 'ukb-mi'):
        return Image3DBackbone(hparams)
    # Name the offending value so a config typo is obvious from the traceback.
    raise NotImplementedError(f"No backbone defined for dataset {dataset!r}")


class Image2DBacknone(nn.Module):
    """EfficientNet-B0 feature extractor for 2D image inputs.

    torchvision's EfficientNet-B0 has its classifier head replaced by an
    identity; a linear layer (``network.fc``) projects the pooled features to
    ``hparams['feature_dim']`` and dropout (p=0.5) is applied to the result.

    NOTE: the class name keeps its historical spelling ("Backnone") because it
    is referenced by name (``backbone`` in this module, possibly elsewhere).

    Args:
        hparams: mapping with at least
            ``'pretrained'`` -- if truthy, load ImageNet-1K weights and apply
            ImageNet mean/std normalisation in ``forward``;
            ``'feature_dim'`` (int) -- output feature size.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        weights = (torchvision.models.EfficientNet_B0_Weights.IMAGENET1K_V1
                   if hparams['pretrained'] else None)
        self.network = nn.Sequential(torchvision.models.efficientnet_b0(weights=weights))

        # torchvision's classifier is Sequential(Dropout, Linear); index 1 is the Linear.
        outdim = self.network[0].classifier[1].in_features
        self.network.add_module('fc', nn.Linear(outdim, self.hparams['feature_dim']))

        # Drop the ImageNet head so network[0] emits pooled features for `fc`.
        self.network[0].classifier = nn.Identity()
        self.dropout = nn.Dropout(0.5)

        # Build the normaliser once rather than on every forward pass.
        # Normalize has no parameters/buffers, so state_dict keys are unchanged.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def forward(self, x):
        """Return dropout-regularised features of size ``feature_dim``.

        Args:
            x: image batch; presumably float (N, 3, H, W) scaled to [0, 1]
                when pretrained weights are used -- TODO confirm against the
                data pipeline (the normaliser needs 3 channels).
        """
        if self.hparams['pretrained']:
            x = self.normalize(x)
        return self.dropout(self.network(x))


class Image3DBackbone(nn.Module):
    """3D ResNet-18 (torchvision ``r3d_18``) feature extractor for 3D inputs.

    The ResNet classification head is replaced by an identity; a linear layer
    (``network[1]``) projects the pooled features to ``hparams['feature_dim']``
    and dropout (p=0.5) is applied to the result.

    Args:
        hparams: mapping with at least
            ``'pretrained'`` -- if truthy, load Kinetics-400 weights;
            ``'feature_dim'`` (int) -- output feature size;
            ``'n_channel'`` (int) -- input channels; values other than 3
            replace the stem convolution.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        # Same multi-weight API as the 2D backbone; the deprecated
        # `pretrained=True` flag maps to exactly this weight enum.
        weights = (torchvision.models.video.R3D_18_Weights.KINETICS400_V1
                   if hparams['pretrained'] else None)
        resnet = torchvision.models.video.r3d_18(weights=weights)
        # Size the projection from the model instead of hard-coding 512.
        self.network = nn.Sequential(
            resnet, nn.Linear(resnet.fc.in_features, self.hparams['feature_dim']))

        if self.hparams['n_channel'] != 3:
            # The stock stem conv expects 3 input channels; swap in one that
            # accepts n_channel. NOTE: this conv is freshly initialised (any
            # pretrained stem weights are discarded) and uses a 7x7x7 kernel
            # with stride 2 on every axis, unlike torchvision's stock stem.
            self.network[0].stem[0] = nn.Conv3d(self.hparams['n_channel'], 64, kernel_size=(7, 7, 7),
                                                stride=(2, 2, 2), padding=(3, 3, 3), bias=False)

        # Replace the classification head so network[0] emits pooled features
        # that feed the projection layer.
        self.network[0].fc = nn.Identity()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return dropout-regularised features of size ``feature_dim``.

        Args:
            x: batched volume; presumably (N, n_channel, D, H, W) float --
                TODO confirm against the data pipeline.
        """
        return self.dropout(self.network(x))
