import torchvision.models as models
from torch import nn as nn
import torch.nn.functional as F
import torch
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights

class ResNet18PenultimateFeatureExtractor(nn.Module):
    """Frozen ImageNet-pretrained ResNet18 used as a fixed feature extractor.

    The final fully connected layer is removed, so each image maps to the
    flattened global-average-pool activation (512 values). When
    ``projection_size`` differs from 512, a frozen, randomly initialised linear
    projection is applied on top.
    """

    def __init__(self, projection_size=512, device=None):
        """
        Initializes the ResNet18 model and modifies it to output penultimate layer features.

        Args:
            projection_size (int, optional): Size of the returned feature vector.
                ResNet18's penultimate layer has 512 features; any other value adds a
                frozen random projection 512 -> projection_size. Defaults to 512.
            device (str, optional): Device to load the model on ('cuda' or 'cpu').
                                    Defaults to CUDA if available, else CPU.
        """
        super().__init__()

        # Set device
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')

        # Load the pretrained ResNet18 and drop its final fully connected layer; the
        # remaining stack ends with the global average pool -> [batch, 512, 1, 1].
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.features = nn.Sequential(*list(resnet.children())[:-1]).to(self.device)
        self.features.eval()
        # forward() always runs the backbone under no_grad, so freeze it explicitly to
        # keep it out of optimizers and trainable-parameter counts.
        for param in self.features.parameters():
            param.requires_grad = False

        # ImageNet normalization statistics. Non-persistent buffers move with
        # module.to(...) but add no state_dict entries.
        self.register_buffer(
            'mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device), persistent=False
        )
        self.register_buffer(
            'std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device), persistent=False
        )

        # ResNet18's penultimate layer already has size 512
        self.projection_size = projection_size
        if self.projection_size != 512:
            flattened_size = 512
            # Non-trainable random projection (Kaiming-normal weights, frozen)
            self.projection = nn.Linear(flattened_size, projection_size, bias=False).to(self.device)
            nn.init.kaiming_normal_(self.projection.weight, mode='fan_out', nonlinearity='relu')
            for param in self.projection.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode for this module while keeping the backbone in eval mode.

        BatchNorm layers in training mode normalise with per-batch statistics and
        update their running averages even under ``torch.no_grad()``. That would
        silently alter the pretrained features whenever a parent calls ``.train()``.

        Returns:
            ResNet18PenultimateFeatureExtractor: ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.features.eval()
        return self

    def forward(self, x):
        """
        Processes the input tensor and extracts penultimate layer features.

        Args:
            x (torch.Tensor): Input tensor of shape [batch, 1, H, W] (e.g. [batch, 1, 84, 84]).
                A 3-channel [batch, 3, H, W] tensor is also accepted. NOTE(review): the
                ImageNet mean/std assume float pixels in [0, 1]; no rescaling is done here.

        Returns:
            torch.Tensor: Extracted features of shape [batch, 512], or
            [batch, projection_size] if a projection is configured.
        """
        # Follow the device the module actually lives on (the buffers move with .to()).
        x = x.to(self.mean.device)

        with torch.no_grad():
            # Resize to ResNet's 224x224 training resolution. Bilinear interpolation is
            # per-channel, so resizing before replicating the grey channel is equivalent.
            x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
            if x.shape[1] == 1:
                x = x.expand(-1, 3, -1, -1)  # grey -> pseudo-RGB: [batch, 3, 224, 224]

            x = (x - self.mean) / self.std
            features = torch.flatten(self.features(x), 1)  # [batch, 512]

        if self.projection_size != 512:
            # Kept outside no_grad so the projection could be made trainable later.
            features = self.projection(features)  # [batch, projection_size]

        return features


class ResNet18FullRes(nn.Module):
    """Frozen ImageNet-pretrained ResNet18 for full-resolution RGB frames.

    Each image is resized to 224x224, normalised with ImageNet statistics and
    mapped to the flattened global-average-pool activation (512 values). A frozen
    random projection is applied when ``projection_size`` differs from 512.
    """

    def __init__(self, projection_size=512, device=None):
        """
        Initializes the ResNet18 model and modifies it to output penultimate layer features,
        accepting a full-resolution (210x160 RGB) image.

        Args:
            projection_size (int, optional): Desired output size; defaults to 512. Any other
                value adds a frozen random projection 512 -> projection_size.
            device (str, optional): Device to load the model on ('cuda' or 'cpu').
                                    Defaults to CUDA if available, else CPU.
        """
        super().__init__()

        # Set device
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')

        # Load the pretrained ResNet18 and drop its final fully connected layer; the
        # remaining stack ends with the global average pool -> [batch, 512, 1, 1].
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.features = nn.Sequential(*list(resnet.children())[:-1]).to(self.device)
        self.features.eval()
        # forward() always runs the backbone under no_grad, so freeze it explicitly.
        for param in self.features.parameters():
            param.requires_grad = False

        # ImageNet normalization statistics (3 channels). Non-persistent buffers move
        # with module.to(...) but add no state_dict entries.
        self.register_buffer(
            'mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device), persistent=False
        )
        self.register_buffer(
            'std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device), persistent=False
        )

        # ResNet18's penultimate layer outputs 512 features
        self.projection_size = projection_size
        if self.projection_size != 512:
            flattened_size = 512
            # Non-trainable random projection (Kaiming-normal weights, frozen)
            self.projection = nn.Linear(flattened_size, projection_size, bias=False).to(self.device)
            nn.init.kaiming_normal_(self.projection.weight, mode='fan_out', nonlinearity='relu')
            for param in self.projection.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode for this module while keeping the backbone in eval mode.

        BatchNorm layers in training mode use batch statistics and update their
        running averages even under ``torch.no_grad()``. That would corrupt the
        pretrained features whenever a parent module calls ``.train()``.

        Returns:
            ResNet18FullRes: ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.features.eval()
        return self

    def forward(self, x):
        """
        Processes the input tensor and extracts penultimate layer features.

        Args:
            x (torch.Tensor): Channels-last frames of shape [batch, H, W, 3]
                (e.g. [batch, 210, 160, 3]). Channels-first [batch, 3, H, W] is also
                accepted. NOTE(review): the ImageNet mean/std assume float pixels in
                [0, 1]; no rescaling is done here.

        Returns:
            torch.Tensor: Extracted features of shape [batch, 512] (or [batch, projection_size] if projection is applied).
        """
        # Follow the device the module actually lives on (the buffers move with .to()).
        x = x.to(self.mean.device)

        with torch.no_grad():
            # Channels-last input (trailing dim of 3) is permuted to NCHW. A tensor that
            # is already NCHW (3 channels in dim 1, something else last) is used as is.
            if x.shape[-1] == 3 or x.shape[1] != 3:
                x = x.permute(0, 3, 1, 2)

            # Resize from the native frame size to 224x224, then apply ImageNet normalization
            x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
            x = (x - self.mean) / self.std

            features = torch.flatten(self.features(x), 1)  # [batch, 512]

        if self.projection_size != 512:
            # Kept outside no_grad so the projection could be made trainable later.
            features = self.projection(features)  # [batch, projection_size]

        return features


class FasterRCNNResNet50FPNPenultimateFeatureExtractor(nn.Module):
    """Frozen COCO-pretrained Faster R-CNN (ResNet50-FPN v2) backbone as a feature extractor.

    Only the detector's backbone is kept. Each image is resized to 224x224,
    normalised with ImageNet statistics, and the FPN 'pool' level is flattened
    into a 4096-value vector. A frozen random projection is applied when
    ``projection_size`` differs from 4096.
    """

    def __init__(self, projection_size=512, device=None):
        """
        Initializes the Faster R-CNN ResNet50 FPN model and modifies it to output penultimate layer features.

        Args:
            projection_size (int, optional): Size of the returned feature vector. Any value
                other than 4096 adds a frozen random projection 4096 -> projection_size.
                Defaults to 512.
            device (str, optional): Device to load the model on ('cuda' or 'cpu').
                                    Defaults to CUDA if available, else CPU.

        The flattened FPN 'pool' level has 4096 values for the fixed 224x224 input used
        in forward() (256 channels x 4 x 4 in torchvision's FPN). The backbone itself
        contains no linear layers.
        """
        super().__init__()

        # Set device
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')

        # Load the pretrained Faster R-CNN model and keep only its backbone
        weights = FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1
        model = models.detection.fasterrcnn_resnet50_fpn_v2(weights=weights)
        self.backbone = model.backbone.to(self.device)
        self.backbone.eval()
        # forward() always runs the backbone under no_grad, so freeze it explicitly.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # ImageNet normalization statistics. Non-persistent buffers move with
        # module.to(...) but add no state_dict entries.
        self.register_buffer(
            'mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device), persistent=False
        )
        self.register_buffer(
            'std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device), persistent=False
        )

        # Set projection size
        self.projection_size = projection_size
        if self.projection_size != 4096:
            flattened_size = 4096
            # Non-trainable random projection (Kaiming-normal weights, frozen)
            self.projection = nn.Linear(flattened_size, projection_size, bias=False).to(self.device)
            nn.init.kaiming_normal_(self.projection.weight, mode='fan_out', nonlinearity='relu')
            for param in self.projection.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode for this module while keeping the backbone in eval mode.

        BatchNorm layers in training mode use batch statistics and update their
        running averages even under ``torch.no_grad()``. That would corrupt the
        pretrained features whenever a parent module calls ``.train()``.

        Returns:
            FasterRCNNResNet50FPNPenultimateFeatureExtractor: ``self``.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """
        Processes the input tensor and extracts penultimate layer features.

        Args:
            x (torch.Tensor): Input tensor of shape [batch, 1, H, W] (e.g. [batch, 1, 84, 84]).
                A 3-channel [batch, 3, H, W] tensor is also accepted. NOTE(review): the
                ImageNet mean/std assume float pixels in [0, 1]; no rescaling is done here.

        Returns:
            torch.Tensor: Extracted features of shape [batch, projection_size]
            ([batch, 4096] when projection_size == 4096).
        """
        # Follow the device the module actually lives on (the buffers move with .to()).
        x = x.to(self.mean.device)

        with torch.no_grad():
            # Resize to 224x224. The 4096-value 'pool' size depends on this resolution.
            # Bilinear is per-channel, so the grey channel can be replicated afterwards.
            x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
            if x.shape[1] == 1:
                x = x.expand(-1, 3, -1, -1)  # grey -> pseudo-RGB: [batch, 3, 224, 224]

            x = (x - self.mean) / self.std

            # The backbone returns a dict of FPN feature maps; take the coarsest 'pool' level
            # ([batch, 256, 4, 4] for a 224x224 input) and flatten it (no pooling).
            pyramid = self.backbone(x)
            features = torch.flatten(pyramid['pool'], 1)  # [batch, 4096]

        if self.projection_size != 4096:
            # Kept outside no_grad so the projection could be made trainable later.
            features = self.projection(features)  # [batch, projection_size]

        return features

class EfficientNetV2PenultimateFeatureExtractor(nn.Module):
    """Frozen ImageNet-pretrained EfficientNetV2-S used as a fixed feature extractor.

    The classifier head is removed, so each image maps to the flattened
    global-average-pool activation (1280 values). A frozen random projection is
    applied when ``projection_size`` differs from 1280.
    """

    def __init__(self, projection_size=1280, device=None):
        """
        Initializes the EfficientNetV2 small model and modifies it to output penultimate layer features.

        Args:
            projection_size (int): Desired size of the output feature vector. For EfficientNetV2-S,
                                   the penultimate features are naturally 1280-dim; any other
                                   value adds a frozen random projection 1280 -> projection_size.
            device (str, optional): Device to load the model on ('cuda' or 'cpu').
                                    Defaults to CUDA if available, else CPU.
        """
        super().__init__()

        # Set device
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')

        # Load the pretrained EfficientNetV2-S. The classifier is the final child module,
        # so keeping all preceding children yields the pooled penultimate features.
        efficientnet = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        self.features = nn.Sequential(*list(efficientnet.children())[:-1]).to(self.device)
        self.features.eval()
        # forward() always runs the backbone under no_grad, so freeze it explicitly.
        for param in self.features.parameters():
            param.requires_grad = False

        # ImageNet normalization statistics. Non-persistent buffers move with
        # module.to(...) but add no state_dict entries.
        self.register_buffer(
            'mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device), persistent=False
        )
        self.register_buffer(
            'std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device), persistent=False
        )

        # EfficientNetV2-S penultimate features are 1280-dimensional.
        self.projection_size = projection_size
        if self.projection_size != 1280:
            flattened_size = 1280
            # Non-trainable random projection (Kaiming-normal weights, frozen)
            self.projection = nn.Linear(flattened_size, projection_size, bias=False).to(self.device)
            nn.init.kaiming_normal_(self.projection.weight, mode='fan_out', nonlinearity='relu')
            for param in self.projection.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Set train/eval mode for this module while keeping the backbone in eval mode.

        BatchNorm layers in training mode use batch statistics and update their
        running averages even under ``torch.no_grad()``. That would corrupt the
        pretrained features whenever a parent module calls ``.train()``.

        Returns:
            EfficientNetV2PenultimateFeatureExtractor: ``self``.
        """
        super().train(mode)
        self.features.eval()
        return self

    def forward(self, x):
        """
        Processes the input tensor and extracts penultimate layer features.

        Args:
            x (torch.Tensor): Input tensor of shape [batch, 1, H, W] (e.g. [batch, 1, 84, 84]).
                A 3-channel [batch, 3, H, W] tensor is also accepted. NOTE(review): the
                ImageNet mean/std assume float pixels in [0, 1]; no rescaling is done here.

        Returns:
            torch.Tensor: Extracted features of shape [batch, 1280] (or [batch, projection_size] if projected).
        """
        # Follow the device the module actually lives on (the buffers move with .to()).
        x = x.to(self.mean.device)

        with torch.no_grad():
            # Resize to 224x224. Bilinear is per-channel, so the grey channel can be
            # replicated after resizing with identical results.
            x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
            if x.shape[1] == 1:
                x = x.expand(-1, 3, -1, -1)  # grey -> pseudo-RGB: [batch, 3, 224, 224]

            x = (x - self.mean) / self.std
            features = torch.flatten(self.features(x), 1)  # [batch, 1280, 1, 1] -> [batch, 1280]

        if self.projection_size != 1280:
            # Kept outside no_grad so the projection could be made trainable later.
            features = self.projection(features)  # [batch, projection_size]

        return features
