import pytorch_lightning as pl
import torch
import torchmetrics
from torch import nn, optim, utils
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import vgg16
from torchvision.transforms import ToTensor


class Classifier(pl.LightningModule):
    """VGG16 image classifier trained with an extra background-regularization term.

    Each training step adds two losses:

    * the usual cross-entropy loss on the batch;
    * a "background" term: the mean squared error between zero and the network's
      outputs on an all-zero input shaped like the batch.

    The background term is scaled by ``background_weight``. Its intent, per the
    original comment, is to keep the model from learning from background content.

    Args:
        num_classes: Number of output classes for the VGG16 head.
        num_epochs: ``T_max`` of the cosine-annealing schedule. The scheduler is
            returned without an interval, so it uses Lightning's default
            (presumably one step per epoch).
        lr: Initial Adam learning rate.
        eta_min: Minimum learning rate reached by the cosine schedule.
        background_weight: Multiplier for the background loss. The default of
            ``1.0`` reproduces the plain ``loss + background_loss`` sum.
    """

    def __init__(self, num_classes: int, num_epochs: int, lr: float, eta_min: float,
                 background_weight: float = 1.0):
        super().__init__()
        # Store the constructor arguments in ``self.hparams`` and in checkpoints.
        # This lets ``Classifier.load_from_checkpoint(path)`` rebuild the module
        # without the caller re-supplying them.
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.lr = lr
        self.eta_min = eta_min
        self.num_epochs = num_epochs
        self.background_weight = background_weight
        # No pretrained weights are requested, so the network starts untrained.
        self.classifier = vgg16(num_classes=num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()
        # Use separate metric objects so train and val accumulate state independently.
        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        """Return the VGG16 scores for ``x``.

        The scores are unnormalized and are fed to ``CrossEntropyLoss``. ``x`` is
        presumably a ``(batch, 3, H, W)`` image tensor (TODO confirm with the caller).
        """
        return self.classifier(x)

    def training_step(self, batch, batch_idx):
        """Compute the classification loss plus the weighted background loss.

        Returns:
            The total loss that Lightning backpropagates.
        """
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)

        # Background regularizer: all-zero inputs shaped like the batch should
        # produce all-zero outputs, discouraging predictions driven by an empty
        # background.
        background_img = torch.zeros_like(x)
        y_hat_background = self(background_img)
        loss_background = self.mse_loss(y_hat_background, torch.zeros_like(y_hat_background))
        total_loss = loss + self.background_weight * loss_background

        # logging to tensorboard by default; only the classification loss goes to the progress bar
        self.log("training_loss", loss, prog_bar=True)
        self.log("background_loss", loss_background)
        self.log("total_loss", total_loss)

        # Passing the metric object lets Lightning compute and reset it at epoch end.
        self.train_acc(y_hat, y)
        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
        return total_loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and epoch-level validation accuracy."""
        x, y = batch
        y_hat = self(x)
        # sync_dist averages the loss across devices (the default train() uses 2).
        self.log("val_loss", self.criterion(y_hat, y), on_step=False, on_epoch=True, sync_dist=True)
        self.val_acc(y_hat, y)
        self.log("val_acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """Use Adam with cosine annealing from ``lr`` down to ``eta_min`` over ``num_epochs``."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.num_epochs, eta_min=self.eta_min)
        return [optimizer], [scheduler]


def train(
        num_classes: int,
        num_epochs=30,
        devices=2,
        lr=1e-4,
        eta_min=1e-6,
        data_root=None,
        train_split=0.8,
        validation_in_train=False,
        *,
        batch_size=64,
        num_workers=10,
        accelerator='gpu',
        seed=None):
    """Train a ``Classifier`` on an ``ImageFolder`` dataset rooted at ``data_root``.

    The dataset is randomly split into train and validation subsets according to
    ``train_split``. If ``validation_in_train`` is true, the whole dataset is
    used for training, and the validation subset is still evaluated, so it
    overlaps the training data.

    Args:
        num_classes: Number of output classes.
        num_epochs: Maximum training epochs; also the cosine schedule length.
        devices: Number of devices passed to ``pl.Trainer``.
        lr: Initial learning rate.
        eta_min: Final learning rate of the cosine schedule.
        data_root: Root directory for ``ImageFolder``. Required.
        train_split: Fraction of samples assigned to the training subset, in [0, 1].
        validation_in_train: Train on the full dataset instead of the train subset.
        batch_size: Batch size for both dataloaders.
        num_workers: Worker processes per dataloader. Workers are kept alive
            between epochs only when ``num_workers > 0``.
        accelerator: Accelerator passed to ``pl.Trainer``.
        seed: If given, seeds the train/validation split so it is reproducible.
            It does not control training-time shuffling.

    Returns:
        The trained ``Classifier``.

    Raises:
        ValueError: If ``data_root`` is missing or ``train_split`` is outside [0, 1].
    """
    # Validate inputs before building the (large) VGG16 model. An explicit raise
    # is used because an ``assert`` would be stripped under ``python -O``.
    if data_root is None:
        raise ValueError("Please provide a valid dataset root")
    if not 0.0 <= train_split <= 1.0:
        raise ValueError(f"train_split must be in [0, 1], got {train_split}")

    classifier = Classifier(num_classes=num_classes, num_epochs=num_epochs, lr=lr, eta_min=eta_min)
    dataset = ImageFolder(root=data_root, transform=ToTensor())
    train_size = int(len(dataset) * train_split)
    valid_size = len(dataset) - train_size
    # Without a seed, use torch's global RNG, which is random_split's own default.
    generator = torch.default_generator if seed is None else torch.Generator().manual_seed(seed)
    train_dataset, validation_dataset = utils.data.random_split(
        dataset, [train_size, valid_size], generator=generator)

    # reset the training set to be the whole dataset
    if validation_in_train:
        train_dataset = dataset

    # DataLoader rejects persistent_workers=True when num_workers == 0.
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                         persistent_workers=num_workers > 0)
    train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    # Validation order is irrelevant and Lightning warns about shuffled val loaders.
    # shuffle=False also keeps an empty split (train_split=1.0) from failing in RandomSampler.
    val_dataloader = DataLoader(validation_dataset, shuffle=False, **loader_kwargs)

    trainer = pl.Trainer(max_epochs=num_epochs, accelerator=accelerator, devices=devices)
    trainer.fit(model=classifier, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
    return classifier
