from models.vicreg.vicreg import VICRegLieModule
from torch import Tensor
import torch
import torch.nn.functional as F
from models.lie_ssl.finetuning import BaseFineTuning


class VICRegLieFinetuningBackpropLosses(BaseFineTuning):
    """Fine-tune a pretrained VICReg Lie backbone with a linear classifier.

    The classifier is applied both to the backbone embedding ``z`` and to
    ``backbone.transform(z)``. The two classification losses are summed and
    backpropagated. Top-k accuracy is tracked on the untransformed logits only.
    """

    model_class = VICRegLieModule

    def load_backbone(self):
        """Load the pretrained ``model_class`` checkpoint at ``self.lie_module_path``.

        Returns:
            The loaded ``model_class`` instance.

        Raises:
            NotImplementedError: if ``self.lie_module_path`` is None.
        """
        if self.lie_module_path is None:
            raise NotImplementedError("Lie SSL path must be given")
        lie_ssl = self.model_class.load_from_checkpoint(
            self.lie_module_path,
            datamodule=self.datamodule,
        )
        print("Loaded backbone from", self.lie_module_path)
        return lie_ssl

    def forward(self, x):
        """Return ``(z, z_transformed)``: the backbone embedding and its transform."""
        z = self.backbone(x)
        z_transformed = self.backbone.transform(z)
        return z, z_transformed

    def unpack_batch(self, batch):
        """Return ``(x, y)`` from an ``(x, y)`` or ``(x, y, extra)`` batch.

        Raises:
            ValueError: if the batch has any other length.
        """
        if len(batch) == 2:
            return batch
        if len(batch) == 3:
            x, y, _ = batch
            return x, y
        raise ValueError(f"can't unpack batch of size {len(batch)}")

    def shared_step(self, batch, stage: str = "train"):
        """Compute, log and return the summed loss on ``z`` and ``z_transformed``.

        Args:
            batch: a batch accepted by ``unpack_batch``.
            stage: prefix used for every logged metric name.
        """
        x, y = self.unpack_batch(batch)
        z, z_transformed = self(x)
        y_hat = self.linear_classifier(z)
        y_transformed_hat = self.linear_classifier(z_transformed)

        z_loss = self.loss_function(y_hat, y)
        z_transformed_loss = self.loss_function(y_transformed_hat, y)
        loss = z_loss + z_transformed_loss

        # Pass the batch size explicitly (as log_accuracies does) instead of
        # letting Lightning infer it from the raw batch tuple.
        self.log_losses(
            z_loss, z_transformed_loss, loss, stage, batch_size=y_hat.shape[0]
        )
        self.log_accuracies(y_hat, y_transformed_hat, y, stage)
        return loss

    def log_losses(self, z_loss, z_transformed_loss, loss, stage: str, batch_size=None):
        """Log the three scalar losses under ``{stage}_*`` names.

        Args:
            z_loss: loss on the untransformed embedding.
            z_transformed_loss: loss on the transformed embedding.
            loss: total loss that is backpropagated.
            stage: metric-name prefix (e.g. ``"train"``).
            batch_size: optional batch size forwarded to ``self.log``. None
                keeps Lightning's default behaviour of inferring it.
        """
        for name, value in (
            ("z_loss", z_loss),
            ("z_transformed_loss", z_transformed_loss),
            ("loss", loss),
        ):
            self.log(
                f"{stage}_{name}",
                value,
                sync_dist=True,
                batch_size=batch_size,
                # loader names are used instead
                add_dataloader_idx=False,
            )

    def log_accuracies(
        self, y_hat: Tensor, y_transformed_hat: Tensor, y: Tensor, stage: str
    ):
        """Update and log the top-k accuracy metrics for ``stage``.

        Only ``y_hat`` (logits of the untransformed embedding) is scored.
        ``y_transformed_hat`` is accepted for signature symmetry but unused.
        The metric objects are looked up as ``{stage}_top_{k}_accuracy``
        attributes, which presumably are created by BaseFineTuning (verify).
        """
        batch_size = y_hat.shape[0]

        for k in self.top_k:
            accuracy_metric = getattr(self, f"{stage}_top_{k}_accuracy")
            # compute accuracy only for the untransformed sample
            accuracy_metric(F.softmax(y_hat, dim=-1), y)
            self.log(
                f"{stage}_top_{k}_accuracy",
                accuracy_metric,
                prog_bar=True,
                sync_dist=True,
                on_epoch=True,
                on_step=False,
                batch_size=batch_size,
                # loader names are used instead
                add_dataloader_idx=False,
            )


class VICRegLieLinearEvalBackpropLosses(VICRegLieFinetuningBackpropLosses):
    """Linear evaluation: the backbone is frozen and kept in eval mode.

    Backbone features (and their transform) are computed under
    ``torch.no_grad()``, so no gradients flow into the backbone.
    """

    model_class = VICRegLieModule

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.backbone.eval()

    def train(self, mode: bool = True):
        """Set train/eval mode on this module while pinning the backbone to eval.

        Any ``self.train()`` call would otherwise switch the backbone back to
        training mode. Lightning may make such a call, e.g. after a
        mid-epoch validation run, depending on version. Layers such as
        BatchNorm/Dropout inside the frozen backbone would then behave as in
        training until the next ``on_train_epoch_start``.
        """
        super().train(mode)
        # The backbone may not exist yet if a parent __init__ calls train().
        if hasattr(self, "backbone"):
            self.backbone.eval()
        return self

    def on_train_epoch_start(self) -> None:
        self.backbone.eval()

    def forward(self, x):
        """Return ``(z, z_transformed)`` computed without gradient tracking."""
        with torch.no_grad():
            z = self.backbone(x)
            z_transformed = self.backbone.transform(z)
        return z, z_transformed

class VICRegNoTransformFinetuningBackpropLosses(VICRegLieFinetuningBackpropLosses):
    """Fine-tune on the raw backbone embedding only; ``transform`` is never used.

    Only one loss exists here. It is reported in both the ``z_loss`` and
    ``z_transformed_loss`` slots of ``log_losses``, presumably so the logged
    metric names match the parent class.
    """

    model_class = VICRegLieModule

    def forward(self, x):
        """Return the backbone embedding of ``x``."""
        return self.backbone(x)

    def shared_step(self, batch, stage: str = "train"):
        """Compute, log and return the classification loss on ``backbone(x)``."""
        inputs, targets = self.unpack_batch(batch)
        logits = self.linear_classifier(self(inputs))
        z_loss = self.loss_function(logits, targets)

        # No transformed branch: reuse the same loss / logits for those slots.
        self.log_losses(z_loss, z_loss, z_loss, stage)
        self.log_accuracies(logits, logits, targets, stage)
        return z_loss

class VICRegNoTransformLinearEvalBackpropLosses(VICRegLieLinearEvalBackpropLosses):
    """Linear evaluation on the frozen backbone embedding, without ``transform``.

    The classifier only sees ``backbone(x)``, computed under
    ``torch.no_grad()``; ``transform`` is not applied at inference.

    NOTE(review): ``load_backbone`` is inherited from
    VICRegLieFinetuningBackpropLosses and raises ``NotImplementedError`` when
    ``lie_module_path`` is None. A Lie module checkpoint path therefore
    still appears to be required, unless BaseFineTuning skips
    ``load_backbone`` in that case (TODO confirm).
    """

    model_class = VICRegLieModule

    def forward(self, x):
        """Return the backbone embedding of ``x`` without gradient tracking."""
        with torch.no_grad():
            z = self.backbone(x)
        return z

    def shared_step(self, batch, stage: str = "train"):
        """Compute, log and return the classification loss on ``backbone(x)``.

        Only one loss exists. It fills both loss slots of ``log_losses`` and
        both logit slots of ``log_accuracies``.
        """
        inputs, targets = self.unpack_batch(batch)
        logits = self.linear_classifier(self(inputs))
        z_loss = self.loss_function(logits, targets)

        self.log_losses(z_loss, z_loss, z_loss, stage)
        self.log_accuracies(logits, logits, targets, stage)
        return z_loss