import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.feature_extraction import create_feature_extractor
from importlib import import_module
import numpy as np


class pattern_norm(nn.Module):
    """L2-normalise each sample of a tensor over all of its non-batch dimensions.

    For inputs with more than two dimensions, every sample ``input[i]`` is
    flattened, divided by its L2 norm (``eps=1e-12`` guards against division
    by zero), multiplied by ``scale`` and reshaped back to the input's shape.
    Inputs with two or fewer dimensions are returned unchanged.

    Args:
        scale: Multiplier applied to the normalised samples. The default of
            ``1.0`` yields unit-L2-norm samples.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale

    def forward(self, input):
        """Return ``input`` with each sample L2-normalised and scaled (rank > 2 only)."""
        sizes = input.size()
        if input.dim() <= 2:
            # NOTE(review): 1-D / 2-D inputs are passed through without
            # normalisation (pre-existing behaviour); confirm this is intended.
            return input
        # flatten(start_dim=1) reshapes, so it also works on non-contiguous
        # tensors, where a plain `.view(-1, prod(sizes[1:]))` raises.
        flat = torch.flatten(input, start_dim=1)
        flat = F.normalize(flat, p=2, dim=1, eps=1e-12)
        # `scale` was previously stored but ignored; 1.0 leaves values unchanged.
        return (flat * self.scale).view(sizes)


def EnDNet(backbone, n_classes, pretrained=True, ssl_pretrained=False):
    """Build a 2-D backbone from ``models.basemodels`` and append ``pattern_norm`` to its pooling.

    Args:
        backbone: Name of a callable exported by ``models.basemodels``. It is
            called with ``n_classes``, ``pretrained`` and ``ssl_pretrained``
            as keyword arguments.
        n_classes: Forwarded to the backbone constructor.
        pretrained: Forwarded to the backbone constructor.
        ssl_pretrained: Forwarded to the backbone constructor.

    Returns:
        The constructed model, whose ``body.avgpool`` attribute has been
        replaced by ``nn.Sequential(model.avgpool, pattern_norm())``.

    Raises:
        ModuleNotFoundError: if ``models.basemodels`` cannot be imported.
        AttributeError: if ``backbone`` is not defined in that module.
    """
    # TODO: body did not go through the pattern norm
    # NOTE(review): the pooling layer is read from `model.avgpool` but the
    # wrapped version is written to `model.body.avgpool`; whether these are
    # the same module, and which one forward() uses, is not visible here.
    model_cls = getattr(import_module("models.basemodels"), backbone)
    net = model_cls(
        n_classes=n_classes,
        pretrained=pretrained,
        ssl_pretrained=ssl_pretrained,
    )
    pooled_then_normed = nn.Sequential(net.avgpool, pattern_norm())
    net.body.avgpool = pooled_then_normed
    return net


def EnDNet3D(backbone, n_classes, pretrained=True):
    """Build a backbone from ``models.basemodels_3d`` and append ``pattern_norm`` to its pooling.

    Args:
        backbone: Name of a callable exported by ``models.basemodels_3d``. It
            is called with ``n_classes`` and ``pretrained`` as keyword
            arguments.
        n_classes: Forwarded to the backbone constructor.
        pretrained: Forwarded to the backbone constructor.

    Returns:
        The constructed model, whose ``body.avgpool`` attribute has been
        replaced by ``nn.Sequential(model.avgpool, pattern_norm())``.

    Raises:
        ModuleNotFoundError: if ``models.basemodels_3d`` cannot be imported.
        AttributeError: if ``backbone`` is not defined in that module.
    """
    # TODO: body did not go through the pattern norm
    # NOTE(review): the pooling layer is read from `model.avgpool` but the
    # wrapped version is written to `model.body.avgpool`; whether these are
    # the same module, and which one forward() uses, is not visible here.
    model_cls = getattr(import_module("models.basemodels_3d"), backbone)
    net = model_cls(n_classes=n_classes, pretrained=pretrained)
    pooled_then_normed = nn.Sequential(net.avgpool, pattern_norm())
    net.body.avgpool = pooled_then_normed
    return net