from typing import Optional, Tuple

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from torch import Tensor

from models.base_model import BaseModel


class FashionCNN(BaseModel):
    """Two-block CNN classifier for single-channel images and 10 classes.

    Architecture based on
    https://www.kaggle.com/pankajj/fashion-mnist-with-pytorch-93-accuracy

    ``fc1`` expects 64 * 6 * 6 features, which is what the two conv/pool blocks
    produce for 28x28 inputs.

    Args:
        learning_rate: Learning rate handed to the optimizer.
        optimizer: Optimizer name used by ``configure_optimizers``; only
            ``"adam"`` is supported.
        datamodule: Optional datamodule. It is used as the source of
            ``test_steps`` when the trainer has no datamodule attached.
    """

    def __init__(
        self,
        learning_rate: float = 1e-3,
        optimizer: str = "adam",
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        super().__init__()

        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.datamodule = datamodule

        # 28x28 input: padded 3x3 conv keeps 28x28, pooling halves it -> (32, 14, 14).
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Unpadded 3x3 conv: 14x14 -> 12x12, pooling halves it -> (64, 6, 6).
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.fc1 = nn.Linear(in_features=64 * 6 * 6, out_features=600)
        # The input here is a flat (N, 600) tensor, so plain element-wise
        # dropout is the right layer. Dropout2d on 2-D input is deprecated and
        # already degraded to element-wise dropout.
        self.drop = nn.Dropout(0.25)
        self.fc2 = nn.Linear(in_features=600, out_features=120)
        self.fc3 = nn.Linear(in_features=120, out_features=10)

        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()

        # Names of the test splits; filled in by setup_test_accuracy().
        self.test_steps = None

    def setup(self, stage=None) -> None:
        """Lightning hook; prepares per-split test metrics before testing."""
        if stage == "test":
            self.setup_test_accuracy()
        return super().setup(stage=stage)

    def setup_test_accuracy(self):
        """Create one ``{step}_test_accuracy`` metric per entry of ``test_steps``.

        ``test_steps`` is read from the trainer's datamodule. If the trainer has
        none, it is read from the datamodule passed to the constructor. Existing
        metrics are reused, so repeated test runs do not recreate them.

        Raises:
            RuntimeError: if no datamodule is available to provide ``test_steps``.
        """
        datamodule = getattr(self.trainer, "datamodule", None)
        if datamodule is None:
            datamodule = self.datamodule
        if datamodule is None:
            raise RuntimeError(
                "FashionCNN needs a datamodule exposing `test_steps` for testing; "
                "pass one to trainer.test(...) or to the FashionCNN constructor."
            )

        self.test_steps = datamodule.test_steps
        for step in self.test_steps:
            metric = f"{step}_test_accuracy"
            if not hasattr(self, metric):
                # torchmetrics metrics are nn.Modules, so setattr registers the
                # metric as a child module of this model.
                setattr(self, metric, torchmetrics.Accuracy())

    def shared_step(self, batch: Tuple[Tensor, Tensor], stage: str = "train"):
        """Compute the loss of one batch and update/log the ``stage`` metrics.

        Args:
            batch: ``(inputs, targets)`` pair.
            stage: Metric prefix. ``f"{stage}_accuracy"`` must name an existing
                accuracy metric attribute, e.g. ``"train"``, ``"val"`` or
                ``"<step>_test"``.

        Returns:
            The cross-entropy loss of the batch.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        accuracy_metric = getattr(self, f"{stage}_accuracy")
        accuracy_metric(F.softmax(logits, dim=-1), y)
        self.log(f"{stage}_loss", loss, sync_dist=True)
        # Log the metric object itself so accumulation and reset happen per epoch.
        self.log(
            f"{stage}_accuracy",
            accuracy_metric,
            prog_bar=True,
            sync_dist=True,
            on_epoch=True,
            on_step=False,
        )

        return loss

    def forward(self, x):
        """Return unnormalized class scores of shape ``(N, 10)`` for input ``x``."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        # NOTE(review): there is no activation between fc1, fc2 and fc3. Apart
        # from dropout, they compose to a single affine map. This is kept
        # unchanged; confirm whether it is intentional.
        out = self.fc1(out)
        out = self.drop(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out

    def training_step(self, batch, batch_idx):
        """Return the training loss of one batch."""
        loss = self.shared_step(batch, stage="train")
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss of one batch."""
        loss = self.shared_step(batch, stage="val")
        return loss

    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
        """Return the test loss of one batch, logged under its test-split name.

        Lightning only passes ``dataloader_idx`` when several test dataloaders
        are used. With a single test dataloader it is omitted, so it defaults
        to the first entry of ``test_steps``.
        """
        test_step = self.test_steps[dataloader_idx]
        loss = self.shared_step(batch, stage=f"{test_step}_test")
        return loss

    def configure_optimizers(self):
        """Build the optimizer named by ``self.optimizer``.

        Raises:
            ValueError: if the optimizer name is not supported.
        """
        if self.optimizer == "adam":
            return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        raise ValueError(f"optimizer {self.optimizer} not implemented")


if __name__ == "__main__":
    # Manual smoke run: fit for a few epochs on the rotated Fashion-MNIST
    # datamodule, then evaluate on its test split(s).
    # `pl` is already bound by the module-level import.
    from datasets import fashion_mnist

    pl_trainer = pl.Trainer(max_epochs=5, gpus=1)
    rotated_dm = fashion_mnist.RotatedFashionMNISTDataModule(
        rotate_every_n_train_images=2
    )
    cnn = FashionCNN()

    pl_trainer.fit(cnn, datamodule=rotated_dm)
    pl_trainer.test(cnn, datamodule=rotated_dm)
