from models.base_model import ShapesBaseModel
from typing import Tuple, Optional
from transformers import CLIPVisionModel
import pytorch_lightning as pl
import torch


class CLIPLinearEval(ShapesBaseModel):
    """Linear probe: trains a fresh linear classifier on a frozen CLIP vision encoder.

    Only ``linear_classifier`` is trained. The backbone is frozen in three ways:
    its parameters have ``requires_grad`` disabled, its forward pass runs under
    ``torch.no_grad()``, and it is kept in eval mode whenever this module is
    switched into train mode.
    """

    def __init__(
        self,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        top_k: Tuple[int, ...] = (1, 10),
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        """Build the frozen backbone and the trainable linear head.

        Args:
            learning_rate: stored on ``self.learning_rate``; presumably read by
                ShapesBaseModel's optimizer setup — confirm there.
            optimizer: optimizer name stored on ``self.optimizer`` (same caveat).
            momentum: stored on ``self.momentum`` (same caveat).
            weight_decay: stored on ``self.weight_decay`` (same caveat).
            top_k: forwarded to ``ShapesBaseModel``.
            datamodule: forwarded to ``ShapesBaseModel``.
        """
        super().__init__(top_k=top_k, datamodule=datamodule)

        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.backbone = self.load_backbone()
        # Freeze explicitly. no_grad() in forward already blocks gradients, but
        # requires_grad=False also keeps the weights out of any optimizer setup
        # that filters on requires_grad, and documents the intent.
        self.backbone.requires_grad_(False)
        # Take the pooled-feature width from the model config instead of
        # hard-coding 768, so a different CLIP checkpoint still lines up.
        self.feature_dim = self.backbone.config.hidden_size
        # NOTE(review): self.num_classes is expected to be set by ShapesBaseModel.__init__ — confirm.
        self.linear_classifier = torch.nn.Linear(self.feature_dim, self.num_classes)

    def load_backbone(self):
        """Load the pretrained CLIP ViT-B/32 vision encoder from HuggingFace."""
        model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        return model

    def train(self, mode: bool = True):
        """Set train/eval mode on this module, but always keep the backbone in eval.

        Lightning may switch the whole module back to train mode during an epoch
        (e.g. after a mid-epoch validation run). Without this override, that
        would turn the backbone's train-mode behaviour (e.g. dropout) back on,
        because ``on_train_epoch_start`` only resets it once per epoch.

        Returns:
            ``self``, matching ``torch.nn.Module.train``.
        """
        super().train(mode)
        # The backbone may not exist yet if train() runs during super().__init__.
        backbone = getattr(self, "backbone", None)
        if backbone is not None:
            backbone.eval()
        return self

    def on_train_epoch_start(self) -> None:
        """Put the frozen backbone in eval mode at the start of each training epoch."""
        self.backbone.eval()

    def forward(self, x):
        """Return classifier logits for a batch of images.

        Args:
            x: pixel tensor passed to CLIP as ``pixel_values`` (presumably
                ``(batch, 3, 224, 224)`` and already CLIP-normalized — confirm
                against the datamodule).

        Returns:
            Output of ``linear_classifier`` applied to the backbone's pooled
            features; the last dimension is ``num_classes``.
        """
        with torch.no_grad():
            # The HuggingFace model takes the images as the ``pixel_values`` keyword.
            feats = self.backbone(pixel_values=x).pooler_output
        return self.linear_classifier(feats)


class CLIPFinetuner(ShapesBaseModel):
    """Fine-tunes the full CLIP vision encoder together with a fresh linear classifier.

    Unlike ``CLIPLinearEval``, the backbone is not frozen: gradients flow
    through the encoder, so all of its weights are updated along with the head.
    """

    def __init__(
        self,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        top_k: Tuple[int, ...] = (1, 10),
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        """Build the trainable backbone and the linear head.

        Args:
            learning_rate: stored on ``self.learning_rate``; presumably read by
                ShapesBaseModel's optimizer setup — confirm there.
                NOTE(review): the 1e-1 default is very large for Adam-based
                fine-tuning of a pretrained ViT; confirm it is overridden in practice.
            optimizer: optimizer name stored on ``self.optimizer`` (same caveat).
            momentum: stored on ``self.momentum`` (same caveat).
            weight_decay: stored on ``self.weight_decay`` (same caveat).
            top_k: forwarded to ``ShapesBaseModel``.
            datamodule: forwarded to ``ShapesBaseModel``.
        """
        super().__init__(top_k=top_k, datamodule=datamodule)

        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.backbone = self.load_backbone()
        # Take the pooled-feature width from the model config instead of
        # hard-coding 768, so a different CLIP checkpoint still lines up.
        self.feature_dim = self.backbone.config.hidden_size
        # NOTE(review): self.num_classes is expected to be set by ShapesBaseModel.__init__ — confirm.
        self.linear_classifier = torch.nn.Linear(self.feature_dim, self.num_classes)

    def load_backbone(self):
        """Load the pretrained CLIP ViT-B/32 vision encoder from HuggingFace."""
        model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        return model

    def forward(self, x):
        """Return classifier logits for a batch of images, with gradients through the backbone.

        Args:
            x: pixel tensor passed to CLIP as ``pixel_values`` (presumably
                ``(batch, 3, 224, 224)`` and already CLIP-normalized — confirm
                against the datamodule).

        Returns:
            Output of ``linear_classifier`` applied to the backbone's pooled
            features; the last dimension is ``num_classes``.
        """
        # The HuggingFace model takes the images as the ``pixel_values`` keyword.
        feats = self.backbone(pixel_values=x).pooler_output
        return self.linear_classifier(feats)
