from typing import Optional, Sequence

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics
import torchvision
from torch import Tensor

from models.base_model import BaseModel


class ImageClassificationModule(BaseModel):
    """Image classifier built on a randomly-initialised torchvision ResNet-50.

    Training, validation and test steps share one implementation
    (:meth:`shared_step`) that computes the cross-entropy loss and logs a
    per-stage loss and accuracy.

    Args:
        learning_rate: Initial learning rate for whichever optimizer is used.
        optimizer: ``"adam"`` or ``"sgd"``; any other value makes
            :meth:`configure_optimizers` raise ``ValueError``.
        momentum: SGD momentum (not used by the Adam branch).
        weight_decay: SGD weight decay (not used by the Adam branch).
        num_classes: Number of outputs of the ResNet classification head.
        test_every_n_epoch: Stored on the instance; not read by this class.
        get_fov: If True, batches are ``(x, y, extra)`` triples and the third
            element is discarded; otherwise batches are ``(x, y)`` pairs.
        datamodule: Optional datamodule stored on the instance; not read by
            this class.
    """

    def __init__(
        self,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        num_classes: int = 10,
        test_every_n_epoch: int = 10,
        get_fov: bool = False,
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        super().__init__()

        self.optimizer = optimizer
        self.model = self.create_resnet(num_classes)
        self.learning_rate = learning_rate
        # Stored explicitly so sgd() does not depend on self.hparams, which
        # this class never populates (there is no save_hyperparameters() call).
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.num_classes = num_classes
        self.test_every_n_epoch = test_every_n_epoch
        self.get_fov = get_fov
        self.datamodule = datamodule

        # One metric instance per stage so accumulated state never mixes
        # between train / val / test.
        # NOTE(review): newer torchmetrics releases require ``task=`` (and
        # ``num_classes=``) for Accuracy; confirm against the pinned version.
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()
        self.test_accuracy = torchmetrics.Accuracy()

    def create_resnet(self, num_classes):
        """Return an untrained torchvision ResNet-50 with ``num_classes`` outputs."""
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained`` in favour
        # of ``weights=None``; confirm against the pinned torchvision version.
        return torchvision.models.resnet50(pretrained=False, num_classes=num_classes)

    def forward(self, x):
        """Return the ResNet's raw outputs (logits) for the image batch ``x``."""
        return self.model(x)

    def shared_step(self, batch: Sequence[Tensor], stage: str = "train") -> Tensor:
        """Run one step for ``stage`` and return the cross-entropy loss.

        Updates ``self.<stage>_accuracy`` and logs ``<stage>_loss`` (using
        Lightning's per-hook default step/epoch settings) and
        ``<stage>_accuracy`` (epoch-level only, shown in the progress bar).

        Args:
            batch: ``(x, y)``, or ``(x, y, extra)`` when ``get_fov`` is True.
                ``x`` is fed straight to the ResNet, which expects an image
                tensor of shape (B, C, H, W).
            stage: ``"train"``, ``"val"`` or ``"test"``; selects the metric
                attribute and the log-key prefix.

        Returns:
            The scalar cross-entropy loss for the batch.
        """
        if self.get_fov:
            x, y, _ = batch  # the third element is not used here
        else:
            x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)

        accuracy_metric = getattr(self, f"{stage}_accuracy")
        # Softmax probabilities (not raw logits) are passed to the metric;
        # presumably the metric version in use expects values in [0, 1].
        accuracy_metric(F.softmax(logits, dim=-1), y)

        self.log(f"{stage}_loss", loss, sync_dist=True)
        # Logging the metric object (rather than a computed value) lets
        # Lightning compute and reset it at epoch end.
        self.log(
            f"{stage}_accuracy",
            accuracy_metric,
            prog_bar=True,
            sync_dist=True,
            on_epoch=True,
            on_step=False,
        )

        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch``."""
        return self.shared_step(batch, stage="train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch``."""
        return self.shared_step(batch, stage="val")

    def test_step(self, batch, batch_idx):
        """Return the test loss for ``batch``."""
        return self.shared_step(batch, stage="test")

    def configure_optimizers(self):
        """Build the optimizer selected by ``self.optimizer``.

        Returns:
            A bare ``Adam`` optimizer for ``"adam"``, or ``([SGD], [StepLR])``
            for ``"sgd"``.

        Raises:
            ValueError: if ``self.optimizer`` is neither ``"adam"`` nor ``"sgd"``.
        """
        if self.optimizer == "adam":
            # Only the learning rate is applied; momentum/weight_decay are
            # used by the SGD branch alone.
            return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        if self.optimizer == "sgd":
            return self.sgd()
        raise ValueError(f"optimizer {self.optimizer} not implemented")

    def sgd(self):
        """Return ``([SGD], [StepLR])`` built from the instance hyper-parameters.

        The scheduler multiplies the learning rate by 0.1 every 30 scheduler
        steps (epochs under Lightning's default ``interval="epoch"``).
        """
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.learning_rate,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
        return [optimizer], [scheduler]
