"""
Image encoders and classifier models.

This module defines the core model structures for vision-based classifiers,
including image encoders and classification heads.
"""

import open_clip
import torch
from src import utils

class ImageEncoder(torch.nn.Module):
    """Image encoder that wraps an OpenCLIP model.

    The architecture / weights are selected from ``args.model``:

    * ``"<arch>__pretrained__<tag>"`` -> architecture ``<arch>`` with
      pretrained tag ``<tag>``;
    * ``"<arch>__init__..."`` -> architecture ``<arch>`` with random
      initialization (``pretrained=None``);
    * anything else -> treated as an architecture name with the
      ``"openai"`` pretrained weights.

    Args:
        args: Arguments containing model configuration. ``args.model`` is
            required; ``args.openclip_cachedir`` and ``args.cache_dir`` are
            optional and default to ``None`` when absent.
        keep_lang: Whether to keep the language model part. When ``False``
            the wrapped model's ``transformer`` attribute (if any) is deleted.
    """

    def __init__(self, args, keep_lang=False):
        super().__init__()

        print(f"Loading {args.model} pre-trained weights.")
        if "__pretrained__" in args.model:
            name, pretrained = args.model.split("__pretrained__")
        elif "__init__" in args.model:
            print("Using random initialization.")
            name, pretrained = args.model.split("__init__")[0], None
        else:
            name = args.model
            pretrained = "openai"

        # Build the model together with its train/val preprocessing transforms.
        (
            self.model,
            self.train_preprocess,
            self.val_preprocess,
        ) = open_clip.create_model_and_transforms(
            name,
            pretrained=pretrained,
            cache_dir=getattr(args, "openclip_cachedir", None),
        )

        self.cache_dir = getattr(args, "cache_dir", None)

        # Drop the text transformer when only image features are needed.
        if not keep_lang and hasattr(self.model, "transformer"):
            delattr(self.model, "transformer")

    def forward(self, images):
        """Encode images to feature vectors.

        Call the module as ``encoder(images)``: ``torch.nn.Module.__call__``
        dispatches here and also runs any registered forward hooks, which a
        custom ``__call__`` override would silently skip.

        Args:
            images: Batch of preprocessed images accepted by the wrapped
                model's ``encode_image``.

        Returns:
            Whatever ``self.model.encode_image(images)`` returns.
        """
        assert self.model is not None
        return self.model.encode_image(images)

    def save(self, filename):
        """Save the whole encoder module to ``filename`` via ``utils.torch_save``."""
        utils.torch_save(self, filename)

    @classmethod
    def load(cls, model_name, filename):
        """Load an encoder previously written by :meth:`save`.

        ``save`` serializes the entire module with ``utils.torch_save``, so
        the file is read back with the matching ``utils.torch_load`` (the
        same loader ``ClassificationHead`` and ``ImageClassifier`` use).
        The previous implementation recursed into ``cls.load`` with the
        loaded object in place of a path and could never succeed.

        Args:
            model_name: Kept for backward compatibility. The loaded module
                already carries its architecture, so the name is not used.
            filename: Path of the file written by :meth:`save`.

        Returns:
            The deserialized encoder.
        """
        del model_name  # unused; see docstring
        return utils.torch_load(filename)

class ClassificationHead(torch.nn.Linear):
    """Linear classification head with optional input L2-normalization.

    Args:
        normalize: Whether to L2-normalize input features along the last
            dimension before the linear layer.
        weights: Required 2-D tensor of shape ``(output_size, input_size)``;
            a copy becomes the layer's weight (keeping its dtype/device).
        biases: Optional bias tensor; a copy becomes the layer's bias. When
            ``None``, a zero bias with the same dtype and device as
            ``weights`` is used.
    """

    def __init__(self, normalize, weights, biases=None):
        output_size, input_size = weights.shape
        super().__init__(input_size, output_size)
        self.normalize = normalize

        # Replace the randomly initialized parameters with the given values.
        self.weight = torch.nn.Parameter(weights.clone())
        if biases is not None:
            self.bias = torch.nn.Parameter(biases.clone())
        else:
            # Match the weight's dtype/device; the default bias allocated by
            # nn.Linear is float32 on CPU and would mismatch otherwise.
            self.bias = torch.nn.Parameter(
                torch.zeros(output_size, dtype=weights.dtype, device=weights.device)
            )

    def forward(self, inputs):
        """Apply optional L2-normalization, then the linear layer.

        ``F.normalize`` clamps the norm at a tiny epsilon, so all-zero feature
        vectors map to zeros instead of NaN; other inputs are unaffected.
        Hooks registered on the module run because ``nn.Module.__call__`` is
        not overridden.
        """
        if self.normalize:
            inputs = torch.nn.functional.normalize(inputs, dim=-1)
        return super().forward(inputs)

    def save(self, filename):
        """Save the whole classification head to ``filename`` via ``utils.torch_save``."""
        utils.torch_save(self, filename)

    @classmethod
    def load(cls, filename):
        """Load a classification head written by :meth:`save`."""
        return utils.torch_load(filename)

class ImageClassifier(torch.nn.Module):
    """Complete image classifier: an image encoder followed by a classification head.

    Args:
        image_encoder: Image encoder module. When not ``None``, its
            ``train_preprocess`` / ``val_preprocess`` attributes are exposed on
            the classifier; otherwise both are set to ``None``.
        classification_head: Module mapping encoder features to outputs.
    """

    def __init__(self, image_encoder, classification_head):
        super().__init__()
        self.image_encoder = image_encoder
        self.classification_head = classification_head

        # Expose the encoder's preprocessing so callers can build data loaders
        # from the classifier alone; keep the attributes defined either way.
        if self.image_encoder is not None:
            self.train_preprocess = self.image_encoder.train_preprocess
            self.val_preprocess = self.image_encoder.val_preprocess
        else:
            self.train_preprocess = None
            self.val_preprocess = None

    def freeze_head(self):
        """Disable gradients for every parameter of the classification head.

        Iterating ``parameters()`` also works for heads without a bias (where
        ``bias`` is ``None``).
        """
        for param in self.classification_head.parameters():
            param.requires_grad_(False)

    def forward(self, inputs, return_features=False):
        """Forward pass with optional feature return.

        Args:
            inputs: Input images.
            return_features: Whether to also return the encoder features,
                L2-normalized along the last dimension.

        Returns:
            Classification outputs, or ``(outputs, normalized_features)`` when
            ``return_features`` is true. All-zero feature vectors normalize to
            zeros rather than NaN.
        """
        features = self.image_encoder(inputs)
        outputs = self.classification_head(features)

        if return_features:
            return outputs, torch.nn.functional.normalize(features, dim=-1)
        return outputs

    def save(self, filename):
        """Save the whole classifier to ``filename`` via ``utils.torch_save``."""
        utils.torch_save(self, filename)

    @classmethod
    def load(cls, filename):
        """Load a classifier written by :meth:`save`."""
        return utils.torch_load(filename)