from models.base_model import ShapesBaseModel
from typing import Tuple, Optional
import pytorch_lightning as pl
import torch
import timm


class ViTLinearEval(ShapesBaseModel):
    """Linear evaluation of an ImageNet-pretrained ViT-B/16.

    Trains a fresh linear classifier on top of the backbone's pooled features
    while the backbone's weights stay frozen (``requires_grad=False``) and the
    backbone is kept in eval mode for the whole run.
    """

    def __init__(
        self,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        top_k: Tuple[int, ...] = (1, 10),
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        """Build the frozen backbone and the trainable linear head.

        Args:
            learning_rate: Stored on ``self``; presumably read by the base
                class's optimizer setup (TODO confirm in ShapesBaseModel).
            optimizer: Optimizer name, stored on ``self`` (same assumption).
            momentum: Momentum, stored on ``self`` (same assumption).
            weight_decay: Weight decay, stored on ``self`` (same assumption).
            top_k: Forwarded to ``ShapesBaseModel``.
            datamodule: Forwarded to ``ShapesBaseModel``; ``self.num_classes``
                used below is expected to be set by the base class.
        """
        super().__init__(top_k=top_k, datamodule=datamodule)

        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.backbone = self.load_backbone()
        # Freeze explicitly: torch.no_grad() in forward already blocks gradient
        # flow, but this makes the intent visible and lets optimizers/summaries
        # treat the backbone as non-trainable.
        self.backbone.requires_grad_(False)
        self.backbone.eval()
        # Derive the width from the backbone so subclasses overriding
        # load_backbone get a correctly sized head (768 for ViT-B/16).
        self.feature_dim = self.backbone.num_features
        self.linear_classifier = torch.nn.Linear(self.feature_dim, self.num_classes)

    def load_backbone(self):
        """Return an ImageNet-pretrained ``vit_base_patch16_224`` from timm."""
        vit = timm.create_model("vit_base_patch16_224", pretrained=True)
        return vit

    def train(self, mode: bool = True):
        """Set train/eval mode on the module, but keep the frozen backbone in eval.

        Lightning calls ``model.train()`` again after every validation run, which
        would otherwise re-enable train-mode behaviour (e.g. dropout) in the
        backbone for the remainder of the epoch.
        """
        super().train(mode)
        # getattr guard: train() may be invoked before the backbone is assigned.
        backbone = getattr(self, "backbone", None)
        if backbone is not None:
            backbone.eval()
        return self

    def on_train_epoch_start(self) -> None:
        """Ensure the frozen backbone is in eval mode at the start of each epoch."""
        self.backbone.eval()

    def forward(self, x):
        """Classify a batch of images with the linear head.

        Args:
            x: Image batch; assumed (batch, 3, 224, 224) to match
                ``vit_base_patch16_224`` — TODO confirm against the datamodule.

        Returns:
            Logits of shape (batch, num_classes).
        """
        with torch.no_grad():
            feats = self.backbone.forward_features(x)
            # timm >= 0.6 returns the unpooled token sequence (B, N, D) here;
            # pool it to one vector per image exactly like the pretrained head
            # does (CLS token / avg + fc_norm). Older timm already returns (B, D).
            if feats.ndim == 3:
                feats = self.backbone.forward_head(feats, pre_logits=True)
        out = self.linear_classifier(feats)
        return out


class ViTFinetuner(ShapesBaseModel):
    """Full fine-tuning of an ImageNet-pretrained ViT-B/16.

    Trains a fresh linear classifier on top of the backbone's pooled features
    AND updates all backbone weights (nothing is frozen; gradients flow through
    the backbone in ``forward``).
    """

    def __init__(
        self,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        top_k: Tuple[int, ...] = (1, 10),
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        """Build the trainable backbone and linear head.

        Args:
            learning_rate: Stored on ``self``; presumably read by the base
                class's optimizer setup (TODO confirm in ShapesBaseModel).
                NOTE(review): the 1e-1 default is large for full fine-tuning.
            optimizer: Optimizer name, stored on ``self`` (same assumption).
            momentum: Momentum, stored on ``self`` (same assumption).
            weight_decay: Weight decay, stored on ``self`` (same assumption).
            top_k: Forwarded to ``ShapesBaseModel``.
            datamodule: Forwarded to ``ShapesBaseModel``; ``self.num_classes``
                used below is expected to be set by the base class.
        """
        super().__init__(top_k=top_k, datamodule=datamodule)

        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.backbone = self.load_backbone()
        # Derive the width from the backbone so subclasses overriding
        # load_backbone get a correctly sized head (768 for ViT-B/16).
        self.feature_dim = self.backbone.num_features
        self.linear_classifier = torch.nn.Linear(self.feature_dim, self.num_classes)

    def load_backbone(self):
        """Return an ImageNet-pretrained ``vit_base_patch16_224`` from timm."""
        vit = timm.create_model("vit_base_patch16_224", pretrained=True)
        return vit

    def forward(self, x):
        """Classify a batch of images; gradients flow into the backbone.

        Args:
            x: Image batch; assumed (batch, 3, 224, 224) to match
                ``vit_base_patch16_224`` — TODO confirm against the datamodule.

        Returns:
            Logits of shape (batch, num_classes).
        """
        feats = self.backbone.forward_features(x)
        # timm >= 0.6 returns the unpooled token sequence (B, N, D) here;
        # pool it to one vector per image exactly like the pretrained head
        # does (CLS token / avg + fc_norm). Older timm already returns (B, D).
        if feats.ndim == 3:
            feats = self.backbone.forward_head(feats, pre_logits=True)
        out = self.linear_classifier(feats)
        return out
