from models.base_model import BaseModel
from pl_bolts.models.self_supervised import SimCLR
import pytorch_lightning as pl
from typing import Optional, Tuple
import torch

from models.resnet_2d.shapes_model import ShapesModule


class SimCLRLinearEvalModule(ShapesModule):
    """Linear evaluation of a pretrained SimCLR backbone.

    Trains a fresh linear classifier on top of SimCLR features while keeping
    the backbone frozen:

    * the backbone forward pass runs under ``torch.no_grad()``, so its weights
      never receive gradients;
    * the backbone is pinned to eval mode (see ``train``), so layers whose
      behaviour depends on the mode (e.g. BatchNorm running statistics,
      dropout) do not change during training.
    """

    def __init__(
        self,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        top_k: Tuple[int, ...] = (1, 10),
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        """Build the frozen backbone and the trainable linear head.

        Args:
            learning_rate: Optimizer learning rate (stored on ``self``;
                presumably consumed by ``ShapesModule``'s optimizer setup).
            optimizer: Optimizer name, e.g. ``"adam"`` (stored on ``self``).
            momentum: Optimizer momentum (stored on ``self``).
            weight_decay: Optimizer weight decay (stored on ``self``).
            top_k: The k values used for top-k accuracy metrics.
            datamodule: Data module used to derive loader names and
                ``num_classes``.
        """
        super().__init__(top_k=top_k, datamodule=datamodule)

        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.top_k = top_k
        # num_classes is set within _setup_loader_names()
        self.num_classes = None
        self.datamodule = datamodule

        self._setup_loader_names()
        self.setup_accuracy_metrics()

        # NOTE(review): self.model is created outside this class, presumably
        # by ShapesModule via load_backbone() - confirm.
        self.backbone = self.model
        self.backbone.eval()
        # 2048 is assumed to be the pooled feature size of the ResNet-50
        # encoder in the pl_bolts ImageNet SimCLR checkpoint; update this if
        # the backbone changes.
        self.linear_classifier = torch.nn.Linear(2048, self.num_classes)

    def load_backbone(self):
        """Load the pretrained pl_bolts SimCLR ImageNet checkpoint.

        ``strict=False`` tolerates checkpoint keys that do not match the
        ``SimCLR`` module exactly.

        Returns:
            The loaded ``SimCLR`` LightningModule.
        """
        weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
        simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
        return simclr

    def train(self, mode: bool = True):
        """Set train/eval mode on this module while keeping the backbone in eval.

        The trainer may switch the whole module back to train mode at any
        point (e.g. after a validation run, including mid-epoch validation).
        Without this override that would put the frozen backbone back into
        training mode and let its normalization statistics drift.

        Args:
            mode: ``True`` for training mode, ``False`` for eval mode.

        Returns:
            ``self``, matching ``torch.nn.Module.train``.
        """
        super().train(mode)
        # ``backbone`` does not exist yet if train()/eval() is invoked during
        # super().__init__(); there is nothing to pin in that case.
        backbone = getattr(self, "backbone", None)
        if backbone is not None:
            backbone.eval()
        return self

    def on_train_epoch_start(self) -> None:
        """Re-apply eval mode to the frozen backbone at the start of each epoch."""
        self.backbone.eval()

    def forward(self, x):
        """Classify ``x`` with the frozen backbone followed by the linear head.

        Args:
            x: Input batch fed directly to the backbone.

        Returns:
            Logits of shape ``(batch, num_classes)`` from the linear classifier.
        """
        # No gradients through the backbone: only the linear head is trained.
        with torch.no_grad():
            feats = self.backbone(x)
        out = self.linear_classifier(feats)
        return out


class SimCLRFineTuneModule(ShapesModule):
    """End-to-end fine-tuning of a pretrained SimCLR backbone.

    Adds a fresh linear classifier on top of SimCLR features and trains it
    together with the backbone. Unlike ``SimCLRLinearEvalModule``, the
    backbone is NOT frozen: gradients flow through it, and it follows the
    module's normal train/eval mode.
    """

    def __init__(
        self,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        top_k: Tuple[int, ...] = (1, 10),
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        """Build the trainable backbone and linear head.

        Args:
            learning_rate: Optimizer learning rate (stored on ``self``;
                presumably consumed by ``ShapesModule``'s optimizer setup).
            optimizer: Optimizer name, e.g. ``"adam"`` (stored on ``self``).
            momentum: Optimizer momentum (stored on ``self``).
            weight_decay: Optimizer weight decay (stored on ``self``).
            top_k: The k values used for top-k accuracy metrics.
            datamodule: Data module used to derive loader names and
                ``num_classes``.
        """
        super().__init__(top_k=top_k, datamodule=datamodule)

        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.top_k = top_k
        # num_classes is set within _setup_loader_names()
        self.num_classes = None
        self.datamodule = datamodule

        self._setup_loader_names()
        self.setup_accuracy_metrics()

        # NOTE(review): self.model is created outside this class, presumably
        # by ShapesModule via load_backbone() - confirm.
        self.backbone = self.model
        # 2048 is assumed to be the pooled feature size of the ResNet-50
        # encoder in the pl_bolts ImageNet SimCLR checkpoint; update this if
        # the backbone changes.
        self.linear_classifier = torch.nn.Linear(2048, self.num_classes)

    def load_backbone(self):
        """Load the pretrained pl_bolts SimCLR ImageNet checkpoint.

        ``strict=False`` tolerates checkpoint keys that do not match the
        ``SimCLR`` module exactly.

        Returns:
            The loaded ``SimCLR`` LightningModule.
        """
        weight_path = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"
        simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
        return simclr

    def forward(self, x):
        """Classify ``x`` with the backbone followed by the linear head.

        Gradients flow through both, so the backbone is fine-tuned.

        Args:
            x: Input batch fed directly to the backbone.

        Returns:
            Logits of shape ``(batch, num_classes)`` from the linear classifier.
        """
        feats = self.backbone(x)
        out = self.linear_classifier(feats)
        return out
