import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torchvision
import torchvision.models as models
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor


# Can be seen as a special case of FiLM: features are only scaled (no shift term).
class FeatureWeighting(nn.Module):
    """Backbone feature extractor whose output is scaled element-wise.

    A torchvision backbone is wrapped with ``create_feature_extractor`` so
    that the output of its ``flatten`` node is exposed under the key
    ``'feature'``. ``forward`` multiplies that feature by a caller-supplied
    weighting ``v``. The weighting is typically obtained from
    ``get_embeddings`` / ``get_selection_embedding``, which squash learnable
    parameters through a temperature-scaled sigmoid.

    Attributes set by ``__init__``:
        base_model_name: the ``base_model`` string passed by the caller.
        temperature: divisor applied to embeddings before the sigmoid.
        base_model: the wrapped feature extractor (left in train mode).
        transforms: ``ResNet18_Weights.IMAGENET1K_V1.transforms``; torchvision
            exposes this as a factory, so call it to build the preprocessing.
        output_dim: size of dimension 1 of the extracted feature.
        embed_dim: equal to ``output_dim``.

    ``embeddings`` and ``selection_embedding`` do not exist until
    ``init_embeddings`` / ``init_selection_embedding`` has been called.
    """

    # Names accepted by ``__init__``; validated before any model is built.
    _SUPPORTED_BASE_MODELS = (
        'resnet18',
        'resnet18_pretrained',
        'Swin_T_pretrained',
        'resnet18_BN',
        'mobilenet_v3_small_IN',
    )

    def __init__(self, base_model, temperature=1.0):
        """Build the backbone and record its feature dimensionality.

        Args:
            base_model: one of ``'resnet18'`` (InstanceNorm, random init),
                ``'resnet18_pretrained'`` (ImageNet weights),
                ``'Swin_T_pretrained'`` (ImageNet weights), ``'resnet18_BN'``
                (default norm layer, random init) or
                ``'mobilenet_v3_small_IN'`` (InstanceNorm, random init).
            temperature: divisor applied to embeddings before the sigmoid.

        Raises:
            ValueError: if ``base_model`` is not a supported name.
        """
        super(FeatureWeighting, self).__init__()

        # Fail fast, before constructing or downloading anything.
        if base_model not in self._SUPPORTED_BASE_MODELS:
            raise ValueError(
                'base_model not implemented: {!r}'.format(base_model)
            )

        self.base_model_name = base_model
        self.temperature = temperature

        backbone = self._build_backbone(base_model)

        # Every backbone shares the same preprocessing factory. The original
        # 'mobilenet_v3_small_IN' branch left this attribute unset, so
        # reading ``transforms`` failed for that backbone only.
        # NOTE(review): the Swin-T branch also uses the ResNet18
        # preprocessing; torchvision ships Swin-specific transforms. Kept
        # as-is so existing results do not change -- confirm it is intended.
        self.transforms = models.resnet.ResNet18_Weights.IMAGENET1K_V1.transforms

        # Debugging aid: get_graph_node_names(backbone) lists valid nodes.
        self.base_model = create_feature_extractor(
            backbone, {'flatten': 'feature'}
        )

        self.output_dim = self._probe_output_dim()
        self.embed_dim = self.output_dim

    @staticmethod
    def _build_backbone(name):
        """Instantiate the raw torchvision model for a supported ``name``."""
        builders = {
            'resnet18': lambda: torchvision.models.resnet18(
                norm_layer=nn.InstanceNorm2d
            ),
            'resnet18_pretrained': lambda: torchvision.models.resnet18(
                weights='IMAGENET1K_V1'
            ),
            'Swin_T_pretrained': lambda: torchvision.models.swin_t(
                weights='IMAGENET1K_V1'
            ),
            'resnet18_BN': lambda: torchvision.models.resnet18(),
            'mobilenet_v3_small_IN': lambda: torchvision.models.mobilenet_v3_small(
                norm_layer=nn.InstanceNorm2d
            ),
        }
        return builders[name]()

    def _probe_output_dim(self):
        """Run one dummy forward pass and return the feature width.

        The extractor is switched to eval mode for the probe and back to
        train mode afterwards. The probe runs under ``torch.no_grad()``
        because only the output shape is needed.
        """
        self.base_model.eval()
        with torch.no_grad():
            probe = self.base_model(torch.rand(1, 3, 224, 224))
        self.base_model.train()
        return probe['feature'].size(1)

    def _temperature(self):
        """Return ``self.temperature``, or 1.0 when the attribute is absent.

        The fallback keeps instances that lack the attribute usable
        (presumably objects created before ``temperature`` was introduced).
        """
        return getattr(self, 'temperature', 1.0)

    def set_backbone(self, requires_grad):
        """Set ``requires_grad`` on every parameter of the backbone."""
        for param in self.base_model.parameters():
            param.requires_grad = requires_grad

    def init_embeddings(self, n_model):
        """Create a zero-initialised ``(n_model, embed_dim)`` parameter."""
        self.embeddings = Parameter(torch.zeros(n_model, self.embed_dim))

    def embedding_parameters(self):
        """Return the embedding parameter wrapped in a list."""
        return [self.embeddings]

    def get_embeddings(self, idx):
        """Return ``sigmoid(embeddings[idx] / temperature)``."""
        # torch.sigmoid is equivalent to the legacy nn.functional.sigmoid.
        return torch.sigmoid(self.embeddings[idx] / self._temperature())

    def init_selection_embedding(self):
        """Create a zero-initialised ``(embed_dim,)`` selection parameter."""
        self.selection_embedding = Parameter(torch.zeros(self.embed_dim))

    def selection_parameters(self):
        """Return the selection parameter wrapped in a list."""
        return [self.selection_embedding]

    def get_selection_embedding(self):
        """Return ``sigmoid(selection_embedding / temperature)``."""
        return torch.sigmoid(self.selection_embedding / self._temperature())

    def forward(self, x: torch.Tensor, v: torch.Tensor, selection: bool = False) -> torch.Tensor:
        """Extract backbone features from ``x`` and scale them by ``v``.

        Args:
            x: input passed to the backbone (the constructor probes with a
                ``(1, 3, 224, 224)`` tensor; presumably image batches).
            v: weighting multiplied with the feature; must be broadcastable
                against it.
            selection: when True the backbone runs under ``torch.no_grad()``,
                so gradients can only flow through ``v``.

        Returns:
            The extracted feature multiplied element-wise by ``v``.
        """
        if selection:
            # Backbone is excluded from the autograd graph in selection mode.
            with torch.no_grad():
                feature = self.base_model(x)['feature']
        else:
            feature = self.base_model(x)['feature']
        return feature * v
        
