from pytorch_lightning import LightningModule
import pytorch_lightning as pl
from abc import abstractmethod
from typing import Optional, Union
from torch import Tensor

import torch.nn.functional as F
import torchmetrics


class BaseModel(LightningModule):
    """Base Lightning module that evaluates features with an online linear probe.

    The constructor stores the optimisation hyper-parameters, calls
    :meth:`load_modules` so subclasses can build their networks, and creates
    one multiclass accuracy metric for training, one for testing and one per
    validation dataset listed in ``datamodule.val_dataset_names``.

    NOTE(review): ``self.online_classifier`` and ``forward`` are used below
    but are not defined in this class; presumably subclasses create them
    (e.g. inside ``load_modules``) -- confirm against the subclasses.
    """

    def __init__(
        self,
        lr: Union[float, str] = "auto",
        momentum: float = 0.9,
        weight_decay: float = 1e-6,
        network: str = "resnet50",
        low_res: bool = False,
        weight_decay_trick: bool = True,
        datamodule: Optional[pl.LightningDataModule] = None,
    ) -> None:
        """Store hyper-parameters, build sub-modules and create the metrics.

        Args:
            lr: Learning rate, or the string ``"auto"``. How ``"auto"`` is
                resolved is not visible in this class.
            momentum: Stored on the instance as ``self.momentum``.
            weight_decay: Stored on the instance as ``self.weight_decay``.
            network: Name of the backbone, stored as ``self.network``.
            low_res: Stored as ``self.low_res``.
            weight_decay_trick: Stored as ``self.weight_decay_trick``.
            datamodule: Data module providing ``num_classes`` and
                ``val_dataset_names``. Although the default is ``None`` it is
                effectively required.

        Raises:
            ValueError: If ``datamodule`` is ``None``.
        """
        super().__init__()
        # Fail with an explicit message instead of an AttributeError on
        # ``None.num_classes`` a few lines further down.
        if datamodule is None:
            raise ValueError(
                "BaseModel requires a `datamodule` exposing `num_classes` and "
                "`val_dataset_names`, but got None."
            )
        self.datamodule = datamodule
        self.num_classes = self.datamodule.num_classes

        # These are assigned before ``load_modules`` so that subclass
        # implementations can read them while building their networks.
        self.network = network
        self.low_res = low_res
        self.weight_decay_trick = weight_decay_trick

        self.load_modules()

        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.online_train_accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=self.num_classes
        )
        self.create_val_accuracy_metrics()
        self.online_test_accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=self.num_classes
        )

    def create_val_accuracy_metrics(self) -> None:
        """Attach one accuracy metric per validation dataset.

        Each metric is stored as the attribute
        ``online_<val_dataset_name>_accuracy`` so that
        :meth:`compute_online_probe` can look it up by stage name.
        """
        for val_dataset_name in self.datamodule.val_dataset_names:
            setattr(
                self,
                f"online_{val_dataset_name}_accuracy",
                torchmetrics.Accuracy(task="multiclass", num_classes=self.num_classes),
            )

    def validation_step(self, batch, batch_idx, dataloader_idx=0) -> Tensor:
        """Run the online probe on one validation batch.

        Args:
            batch: Indexable batch whose first two items are the images and
                the targets.
            batch_idx: Index of the batch (unused here).
            dataloader_idx: Index into ``datamodule.val_dataset_names`` that
                selects which validation dataset this batch belongs to.

        Returns:
            The cross-entropy loss of the online linear probe.
        """
        val_dataset_name = self.datamodule.val_dataset_names[dataloader_idx]
        images, targets = batch[0], batch[1]
        # Flatten everything after the batch dimension before probing.
        features = self.forward(images).flatten(start_dim=1)
        return self.compute_online_probe(features, targets, val_dataset_name)

    def compute_online_probe(self, z: Tensor, y: Tensor, stage: str) -> Tensor:
        """Compute and log the online linear-probe loss and accuracy.

        Args:
            z: Features fed to ``self.online_classifier``. They are detached,
                so the probe loss does not back-propagate into them.
            y: Targets; singleton dimensions are squeezed away.
            stage: Name used to pick the ``online_<stage>_accuracy`` metric
                and to build the logged keys.

        Returns:
            The cross-entropy loss between the probe logits and ``y``.
        """
        logits_online_probe = self.online_classifier(z.detach())
        y = y.squeeze()
        # With a single-sample batch ``squeeze`` also removes the batch
        # dimension, leaving a 0-dim tensor that no longer matches the
        # (1, num_classes) logits. Restore the batch dimension in that case.
        if y.dim() == 0:
            y = y.unsqueeze(0)
        loss_online_probe = F.cross_entropy(logits_online_probe, y)
        self.log(f"{stage}_online_linear_probe_loss", loss_online_probe, sync_dist=True)

        accuracy_metric = getattr(self, f"online_{stage}_accuracy")
        accuracy_metric(F.softmax(logits_online_probe, dim=-1), y)
        # The metric object itself is logged, with epoch-level logging only.
        self.log(
            f"online_{stage}_accuracy",
            accuracy_metric,
            prog_bar=True,
            sync_dist=True,
            on_epoch=True,
            on_step=False,
        )
        return loss_online_probe

    @abstractmethod
    def load_modules(self) -> None:
        """Build the model's sub-modules; meant to be overridden by subclasses.

        Called from ``__init__`` after ``network``, ``low_res`` and
        ``weight_decay_trick`` are set and before ``lr``, ``momentum`` and
        ``weight_decay`` are assigned.
        """
        pass
