import torchvision
from torch.utils.data import DataLoader
from torchvision.models.feature_extraction import create_feature_extractor
import torch
import torch.nn as nn

class CoreModel(nn.Module):
    '''
    CNN backbone model.

    Builds a torchvision ResNet-50, fills it with weights read from a
    checkpoint file and exposes the activations of one intermediate node
    (``readout_layer``) through a torchvision feature extractor. Inputs are
    passed through ``Resize(224)`` and a fixed per-channel ``Normalize``
    before reaching the network.
    '''

    # Only checkpoint entries whose key starts with this prefix are loaded;
    # the prefix is stripped so the remaining name matches the ResNet keys.
    _CHECKPOINT_PREFIX = 'module.model.'

    def __init__(self, weights, readout_layer):
        '''
        Args:
            weights: checkpoint location handed to ``torch.load``. The loaded
                object must contain a ``'model'`` entry mapping parameter
                names to tensors.
            readout_layer: name of the graph node whose output ``forward``
                returns. Any ``'relu_3'`` in the name is rewritten to
                ``'relu_2'`` (see the comment in the body).
        '''
        super().__init__()

        # Use the GPU when one is usable, otherwise fall back to the CPU so
        # the model can still be constructed on CPU-only hosts.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        resnet = torchvision.models.resnet50()

        # Deserialize onto the CPU: ``resnet`` still lives there at this
        # point, and this also lets checkpoints saved from GPU tensors load
        # on hosts without CUDA.
        # NOTE(review): torch.load unpickles the file -- only point this at
        # checkpoints from a trusted source.
        checkpoint = torch.load(weights, map_location='cpu')

        prefix = self._CHECKPOINT_PREFIX
        state_dict = {
            name[len(prefix):]: tensor
            for name, tensor in checkpoint['model'].items()
            if name.startswith(prefix)
        }

        resnet.load_state_dict(state_dict)
        resnet.eval()

        # torch feature extractor labels as relu, relu_1, relu_2 and torchlens uses relu_1, relu_2, relu_3
        # NOTE(review): only the relu_3 -> relu_2 rename is applied here;
        # confirm whether relu_1 / relu_2 names coming from torchlens also
        # need shifting before relying on them.
        self.readout_layer = readout_layer.replace('relu_3', 'relu_2')
        self.feature_extractor = create_feature_extractor(
            resnet, [self.readout_layer]).to(self.device)

        self.downsample = torchvision.transforms.Resize(224).to(self.device)
        # Per-channel statistics for a 3-channel input (these look like the
        # usual ImageNet mean/std -- confirm against the training setup).
        self.normalize = torchvision.transforms.Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225]))

    def forward(self, x):
        '''
        Return the ``readout_layer`` activations for ``x``.

        Args:
            x: image tensor; presumably a (N, 3, H, W) float batch that is
                already on ``self.device`` -- it is not moved here.

        Returns:
            The feature extractor's output stored under ``self.readout_layer``.
        '''
        x = self.downsample(x)
        x = self.normalize(x)

        # The extractor returns a dict keyed by node name; pick the one
        # requested at construction time.
        z = self.feature_extractor(x)[self.readout_layer]

        return z
