import torch
import torchvision
from transformers import BertModel


class BinaryPreTrainedNet(torch.nn.Module):
    """Pretrained torchvision CNN with a single-output binary classification head.

    The backbone is loaded with ``IMAGENET1K_V1`` weights. Its final
    classification layer is then replaced by a ``Linear(in_features, 1)``
    layer. As a result, ``forward_logits`` returns one raw score per sample and
    ``forward`` maps that score into (0, 1) with a sigmoid.

    Args:
        model_type: ``'resnet'`` (ResNet-50) or ``'densenet'`` (DenseNet-121).

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """

    def __init__(self, model_type):
        super().__init__()
        # Kept for attribute compatibility; nothing in this class reads it.
        self.activation = None
        if model_type == 'resnet':
            weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
            self.model = torchvision.models.resnet50(weights=weights)
            # Replace the ImageNet head. in_features is taken from the existing
            # layer (2048 for ResNet-50) rather than being hard-coded.
            self.model.fc = torch.nn.Linear(
                in_features=self.model.fc.in_features, out_features=1)
        elif model_type == 'densenet':
            weights = torchvision.models.DenseNet121_Weights.IMAGENET1K_V1
            self.model = torchvision.models.densenet121(weights=weights)
            # DenseNet's head is `classifier`; in_features is 1024 for DenseNet-121.
            self.model.classifier = torch.nn.Linear(
                in_features=self.model.classifier.in_features,
                out_features=1,
                bias=True,
            )
        else:
            # An f-string ensures that a non-str model_type (e.g. None) still
            # yields this ValueError rather than a TypeError from `str + obj`.
            raise ValueError(f"Model type {model_type} doesn't exist.")

        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid probabilities for batch ``x``.

        ``x`` is presumably an image batch shaped (N, 3, H, W), but this is
        not checked here.
        """
        return self.sigmoid(self.forward_logits(x))

    def forward_logits(self, x):
        """Return the backbone's raw single-output scores for ``x`` (no sigmoid)."""
        return self.model(x)


class JswNet(torch.nn.Module):
    """Two-layer perceptron: 16 inputs -> 1024 sigmoid hidden units -> 1 output.

    ``forward_logits`` returns the raw output of the second linear layer.
    ``forward`` additionally applies a sigmoid to that output.
    """

    def __init__(self):
        super().__init__()
        # Attribute names determine state_dict keys; keep them stable.
        self.linear_one = torch.nn.Linear(16, 1024)
        self.linear_two = torch.nn.Linear(1024, 1)

    def forward_logits(self, x):
        """Return the pre-activation output for ``x`` (last dimension must be 16)."""
        hidden = torch.sigmoid(self.linear_one(x))
        return self.linear_two(hidden)

    def forward(self, x):
        """Return ``sigmoid(forward_logits(x))``."""
        return torch.sigmoid(self.forward_logits(x))



class BertClassifier(torch.nn.Module):
    """Single-output classifier head on top of the pretrained ``prajjwal1/bert-tiny`` encoder.

    Args:
        hidden_dropout_prob: forwarded to ``BertModel.from_pretrained``.
        attention_probs_dropout_prob: forwarded to ``BertModel.from_pretrained``.
    """

    def __init__(self, hidden_dropout_prob=0, attention_probs_dropout_prob=0):
        super().__init__()
        checkpoint = 'prajjwal1/bert-tiny'
        self.bert = BertModel.from_pretrained(
            checkpoint,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
        )
        # NOTE(review): 128 is assumed to equal the encoder's hidden size for
        # this checkpoint -- confirm against self.bert.config.hidden_size.
        self.linear = torch.nn.Linear(128, 1)

    def forward_logits(self, input_id, mask):
        """Return the raw linear-head score computed from the encoder's pooled output.

        With ``return_dict=False`` the encoder returns a tuple; its second
        element (BertModel's pooler output) is fed to the linear head.
        """
        _, pooled = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        return self.linear(pooled)

    def forward(self, input_id, mask):
        """Return ``sigmoid(forward_logits(input_id, mask))``."""
        return torch.sigmoid(self.forward_logits(input_id, mask))



class JswNetMediator(torch.nn.Module):
    """Pretrained torchvision CNN whose classification head outputs 16 values.

    The backbone is loaded with ``IMAGENET1K_V1`` weights. Its final
    classification layer is then replaced by ``Linear(in_features, 16)``. No
    activation is applied in ``forward``, so it returns the same raw outputs as
    ``forward_logits``.

    NOTE(review): the 16 outputs presumably feed ``JswNet`` (which takes 16
    input features); confirm with the training code.

    Args:
        model_type: ``'resnet'`` (ResNet-50) or ``'densenet'`` (DenseNet-121).

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """

    def __init__(self, model_type):
        super().__init__()
        # Kept for attribute compatibility; nothing in this class reads it.
        self.activation = None
        if model_type == 'resnet':
            weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
            self.model = torchvision.models.resnet50(weights=weights)
            # Replace the ImageNet head; in_features is 2048 for ResNet-50.
            self.model.fc = torch.nn.Linear(
                in_features=self.model.fc.in_features, out_features=16)
        elif model_type == 'densenet':
            weights = torchvision.models.DenseNet121_Weights.IMAGENET1K_V1
            self.model = torchvision.models.densenet121(weights=weights)
            # Bug fix: the head must replace `classifier`. The previous code
            # overwrote `features`, discarding the whole convolutional feature
            # extractor so image batches could no longer pass through the network.
            self.model.classifier = torch.nn.Linear(
                in_features=self.model.classifier.in_features,
                out_features=16,
                bias=True,
            )
        else:
            # An f-string ensures that a non-str model_type (e.g. None) still
            # yields this ValueError rather than a TypeError from `str + obj`.
            raise ValueError(f"Model type {model_type} doesn't exist.")

        # Defined but not used by forward/forward_logits; kept so existing
        # attribute access keeps working.
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Return the backbone's raw 16-dimensional outputs for ``x`` (no activation)."""
        return self.forward_logits(x)

    def forward_logits(self, x):
        """Return the backbone's raw 16-dimensional outputs for ``x``.

        ``x`` is presumably an image batch shaped (N, 3, H, W), but this is
        not checked here.
        """
        return self.model(x)
