import torch.nn as nn
import warnings
import torch
import torchvision.models as models
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from clip import load
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Ignore every warning category for the whole process as soon as this module
# is imported (it edits the global warnings filter list).
# NOTE(review): this also hides torch/torchvision deprecation warnings; consider
# narrowing the filter by category or module.
warnings.filterwarnings("ignore")




class Net(nn.Module):
    """Two-layer MLP that also exposes its hidden representation.

    Layers: Linear(input_size -> hidden_size) -> ReLU -> Dropout(p=0.0)
    -> Linear(hidden_size -> num_classes).

    ``forward`` returns ``(logit, representation)``. ``representation`` is the
    post-ReLU hidden activation that is fed into ``fc2``.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        # p=0.0: this layer currently passes activations through unchanged.
        self.fc1_drop = nn.Dropout(0.0)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        hidden = self.fc1_drop(self.relu(self.fc1(x)))
        return self.fc2(hidden), hidden




class FC(nn.Module):
    """Two-layer MLP head returning only logits.

    Layers: Linear(hidden_size -> hidden_size) -> ReLU -> Dropout(p=0.0)
    -> Linear(hidden_size -> num_classes).

    NOTE(review): ``input_size`` is accepted but unused. ``fc1`` maps
    ``hidden_size -> hidden_size``, so inputs must have ``hidden_size``
    features. Confirm whether ``nn.Linear(input_size, hidden_size)`` was
    intended before changing it, because that changes checkpoint shapes.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()
        # p=0.0: identity layer as configured.
        self.fc1_drop = nn.Dropout(0.0)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        activated = self.relu(self.fc1(x))
        activated = self.fc1_drop(activated)
        return self.fc2(activated)
    


class FC_3(nn.Module):
    """Three-layer MLP classifier head.

    Forward path: fc2 (hidden_size -> neural_size_1) -> ReLU
    -> fc3 (neural_size_1 -> neural_size_2) -> ReLU -> Dropout(p=0.0)
    -> last_layer (neural_size_2 -> num_classes).

    NOTE(review): ``fc2_drop`` (p=0.2) is registered but never applied in
    ``forward``. Confirm whether that is intentional before wiring it in,
    since doing so changes training behaviour.
    """

    def __init__(self, hidden_size, neural_size_1, neural_size_2, num_classes):
        super().__init__()
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, neural_size_1)
        self.fc2_drop = nn.Dropout(0.2)
        self.fc3 = nn.Linear(neural_size_1, neural_size_2)
        self.fc3_drop = nn.Dropout(0.0)
        self.last_layer = nn.Linear(neural_size_2, num_classes)

    def forward(self, x):
        h = x
        for stage in (self.fc2, self.relu, self.fc3, self.relu, self.fc3_drop):
            h = stage(h)
        return self.last_layer(h)
    

class FC_4(nn.Module):
    """Four-layer MLP classifier head.

    Forward path:
        fc2 (hidden_size -> neural_size_1) -> ReLU -> Dropout(0.0)
        -> fc3 (neural_size_1 -> neural_size_2) -> ReLU -> Dropout(0.2)
        -> fc4 (neural_size_2 -> neural_size_2) -> ReLU -> Dropout(fc4_dropout)
        -> last_layer (neural_size_2 -> num_classes)

    Args:
        hidden_size: Number of input features.
        neural_size_1: Width of the first hidden layer.
        neural_size_2: Width of the second and third hidden layers.
        num_classes: Number of output logits.
        fc4_dropout: Dropout probability applied after ``fc4``. Defaults to
            0.0, which makes that layer an identity.
    """

    def __init__(self, hidden_size, neural_size_1, neural_size_2, num_classes, *, fc4_dropout=0.0):
        super(FC_4, self).__init__()
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, neural_size_1)
        self.fc2_drop = nn.Dropout(0.0)
        self.fc3 = nn.Linear(neural_size_1, neural_size_2)
        self.fc3_drop = nn.Dropout(0.2)
        self.fc4 = nn.Linear(neural_size_2, neural_size_2)
        # forward() applies a dropout after fc4, so it must be registered here;
        # without it forward() raises AttributeError.
        self.fc4_drop = nn.Dropout(fc4_dropout)
        self.last_layer = nn.Linear(neural_size_2, num_classes)

    def forward(self, x):
        out = self.fc2_drop(self.relu(self.fc2(x)))
        out = self.fc3_drop(self.relu(self.fc3(out)))
        # Call the dropout module on the tensor. Assigning the module itself
        # would feed an nn.Module into last_layer.
        out = self.fc4_drop(self.relu(self.fc4(out)))
        return self.last_layer(out)
    

class FC_5(nn.Module):
    """Five-layer MLP classifier head.

    Forward path: four Linear+ReLU stages
    (hidden_size -> neural_size_1 -> neural_size_2 -> neural_size_2
    -> neural_size_2 - 3), then ``last_layer`` to ``num_classes`` logits.
    ``neural_size_2`` must exceed 3 for ``fc5`` to have a non-empty output.

    NOTE(review): ``fc2_drop``, ``fc3_drop`` and ``fc4_drop`` are registered
    but never applied in ``forward``. Confirm whether that is intentional.
    """

    def __init__(self, hidden_size, neural_size_1, neural_size_2, num_classes):
        super().__init__()
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, neural_size_1)
        self.fc2_drop = nn.Dropout(0.2)
        self.fc3 = nn.Linear(neural_size_1, neural_size_2)
        self.fc3_drop = nn.Dropout(0.2)
        self.fc4 = nn.Linear(neural_size_2, neural_size_2)
        self.fc4_drop = nn.Dropout(0.05)
        self.fc5 = nn.Linear(neural_size_2, neural_size_2 - 3)
        self.last_layer = nn.Linear(neural_size_2 - 3, num_classes)

    def forward(self, x):
        hidden = x
        for linear in (self.fc2, self.fc3, self.fc4, self.fc5):
            hidden = self.relu(linear(hidden))
        return self.last_layer(hidden)
    


class ResNetFeatureExtractor(nn.Module):
    """Frozen ImageNet ResNet-18 backbone with a trainable 2-way linear head.

    ``forward`` returns ``(logits, rep)``. ``rep`` is the flattened output of
    the backbone with its final ``fc`` layer removed (512 features, matching
    ``self.fc``).
    """

    def __init__(self):
        super(ResNetFeatureExtractor, self).__init__()
        # `pretrained=True` is deprecated since torchvision 0.13. This weights
        # enum selects the same ImageNet checkpoint.
        resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Keep every child except the final classifier layer.
        self.features = nn.Sequential(*list(resnet18.children())[:-1])
        for param in self.features.parameters():
            param.requires_grad = False
        # NOTE(review): requires_grad=False does not stop BatchNorm layers in
        # `features` from updating running statistics in train() mode; confirm
        # that is intended.
        self.fc = nn.Linear(512, 2)

    def forward(self, x):
        # Flatten everything after the batch dimension.
        rep = torch.flatten(self.features(x), 1)
        out = self.fc(rep)
        return out, rep



class CLIPRN50FeatureExtractor(nn.Module):
    """Frozen CLIP RN50 image encoder with a trainable 2-way linear head.

    ``forward`` does the following:
      * moves the input to ``device`` and resizes it to 224x224;
      * applies CLIP's normalization;
      * encodes it with the frozen CLIP image tower under ``torch.no_grad()``;
      * returns ``(logits, features)``, where ``features`` is the float32 image
        embedding (1024 features, matching ``self.fc``).

    Args:
        device: Device on which the CLIP model and the linear head are placed.
            Inputs are moved there in ``forward``.
    """

    def __init__(self, device = "cuda"):
        super(CLIPRN50FeatureExtractor, self).__init__()
        self.clip_model, self.preprocess = load("RN50", device=device, jit=False)
        for param in self.clip_model.parameters():
            param.requires_grad = False
        # The head consumes encoder output that lives on `device`, so build it
        # there as well. forward() then works without an explicit .to(device).
        self.fc = nn.Linear(1024, 2).to(device)
        # CLIP image normalization constants.
        self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                              std=[0.26862954, 0.26130258, 0.27577711])
        self.device = device

    def forward(self, x):
        # NOTE(review): assumes x is a (batch, 3, H, W) float tensor scaled to
        # [0, 1]; confirm against the data pipeline.
        x = x.to(self.device)
        x = nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        x = self.normalize(x)
        # Match the encoder's parameter dtype (presumably fp16 when loaded on CUDA).
        x = x.to(self.clip_model.dtype)

        with torch.no_grad():
            features = self.clip_model.encode_image(x)
        # Upcast so the float32 head receives float32 input.
        features = features.float()
        out = self.fc(features)

        return out, features



class CLIPViTB16FeatureExtractor(nn.Module):
    """Frozen ViT-B/16 backbone with a trainable 2-way linear head.

    NOTE(review): despite the class name, the backbone is not a CLIP model. It
    is torchvision's ``vit_b_16`` loaded with ``ViT_B_16_Weights.IMAGENET1K_V1``,
    and inputs are normalized with ImageNet mean/std.

    ``forward`` resizes the input to 224x224, normalizes it, and returns
    ``(logits, features)``. ``features`` is the backbone output with its
    classification head replaced by an identity (768 features, matching
    ``self.fc``).
    """

    def __init__(self):
        super().__init__()
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        # Drop the classifier so the backbone returns its embedding directly.
        self.vit.heads = nn.Identity()
        self.fc = nn.Linear(768, 2)

        # Freeze the backbone; only `fc` stays trainable.
        for p in self.vit.parameters():
            p.requires_grad = False

        # ImageNet normalization, applied after resizing in forward().
        self.preprocess = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def forward(self, x):
        resized = nn.functional.interpolate(
            x, size=(224, 224), mode="bilinear", align_corners=False
        )
        features = self.vit(self.preprocess(resized))
        return self.fc(features), features


class CNNModel(nn.Module):
    """Three-block CNN classifier that also returns its penultimate features.

    Each block is Conv3x3 (stride 1, padding 1) -> ReLU -> MaxPool 2x2, so the
    spatial size is halved three times. Channels go 3 -> 32 -> 64 -> 128.
    ``fc1`` expects ``128 * 20 * 20`` inputs, which e.g. a 3x160x160 image
    produces.

    Args:
        num_classes: Number of output logits.

    Returns (from ``forward``):
        ``(out, rep)``. ``rep`` is the post-ReLU ``fc1`` activation (50
        features, taken before dropout); ``out`` is the ``fc2`` logits.

    Raises (from ``forward``):
        RuntimeError: if the input's spatial size does not flatten to
        ``128 * 20 * 20`` features per sample.
    """

    def __init__(self, num_classes=10):
        super(CNNModel, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(128 * 20 * 20, 50)
        self.fc2 = nn.Linear(50, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        # Flatten per sample, keeping the batch dimension. A wrongly-sized
        # input then makes fc1 raise a shape error instead of being silently
        # reshaped into extra batch rows.
        x = torch.flatten(x, 1)
        rep = F.relu(self.fc1(x))
        out = self.fc2(self.dropout(rep))

        return out, rep



class GenderCNN(nn.Module):
    """Three-block CNN binary classifier that also returns its hidden features.

    Each block is Conv3x3 (stride 1, padding 1) -> ReLU -> MaxPool 2x2.
    Channels go 3 -> 16 -> 32 -> 64. ``fc1`` expects ``64 * 20 * 20`` inputs,
    which e.g. a 3x160x160 image produces after the three poolings.

    Returns (from ``forward``):
        ``(out, rep)``. ``out`` has 2 logits per sample; ``rep`` is the
        post-ReLU ``fc1`` activation (512 features).

    Raises (from ``forward``):
        RuntimeError: if the input's spatial size does not flatten to
        ``64 * 20 * 20`` features per sample.
    """

    def __init__(self):
        super(GenderCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=64 * 20 * 20, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=2)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        # Flatten per sample, keeping the batch dimension. A wrongly-sized
        # input then makes fc1 raise a shape error instead of being silently
        # reshaped into extra batch rows.
        x = torch.flatten(x, 1)
        rep = F.relu(self.fc1(x))
        out = self.fc2(rep)
        return out, rep

    



class LogisticRegressionModel(nn.Module):
    """Single linear layer mapping ``hid_size`` features to 2 logits.

    No sigmoid/softmax is applied; ``forward`` returns raw logits.
    """

    def __init__(self, hid_size):
        super().__init__()
        self.linear = nn.Linear(hid_size, 2)

    def forward(self, x):
        return self.linear(x)
