import torch
import pytorch_lightning as pl
from torchmetrics import Accuracy
import hydra

class MNISTTask(pl.LightningModule):
    """Lightning module wiring a Hydra-configured model, criterion and optimizer.

    The model, criterion and optimizer settings are each built from the
    corresponding ``hparams`` entry via ``hydra.utils.instantiate``. Training
    uses Adam with an exponentially decaying learning rate (gamma 0.95).
    """

    def __init__(self, hparams):
        """Build the task from a hyperparameter container.

        Args:
            hparams: Container with ``model``, ``criterion`` and ``optimizer``
                entries, each instantiable by ``hydra.utils.instantiate``.
        """
        super().__init__()
        self.save_hyperparameters(hparams)
        self.model = hydra.utils.instantiate(self.hparams.model)
        self.criterion = hydra.utils.instantiate(self.hparams.criterion)
        # Only the ``lr`` attribute of this object is read (in
        # configure_optimizers); the real optimizer is constructed there.
        self.optimizer = hydra.utils.instantiate(self.hparams.optimizer)
        # Per-batch validation losses, averaged and cleared at epoch end.
        self.validation_step_outputs = []

    def _shared_step(self, batch):
        """Run the model on ``batch`` and return ``(loss, accuracy)``.

        Assumes ``batch`` unpacks to ``(inputs, targets)``, with class scores on
        the model output's last dimension and targets comparable to the
        arg-max indices — TODO confirm against the datamodule.
        """
        X, y = batch
        y_preds = self.model(X)
        # The criterion receives the whole batch and returns a 2-tuple; only
        # the first element (the loss) is used.
        loss, _ = self.criterion(y_preds, batch)
        # Fraction of samples whose arg-max prediction equals the target.
        correct = torch.eq(torch.argmax(y_preds, -1), y).to(torch.float32)
        acc = torch.sum(correct) / len(y)
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute and log training loss/accuracy for one batch.

        Returns:
            ``{'loss': loss}`` so Lightning can backpropagate it.
        """
        loss, acc = self._shared_step(batch)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Compute, log and record validation loss/accuracy for one batch.

        Returns:
            ``{'loss': loss}`` for the batch.
        """
        loss, acc = self._shared_step(batch)
        # Detach so the stored losses never hold on to an autograd graph.
        self.validation_step_outputs.append(loss.detach())
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return {'loss': loss}

    def on_validation_epoch_end(self):
        """Log the unweighted mean of this epoch's per-batch validation losses."""
        if not self.validation_step_outputs:
            # torch.stack raises on an empty list; nothing to average if no
            # validation batch was recorded.
            return
        epoch_average_loss = torch.stack(self.validation_step_outputs).mean()
        self.log("validation_epoch_average_loss", epoch_average_loss)
        self.validation_step_outputs.clear()  # free memory

    def configure_optimizers(self):
        """Create Adam over the model parameters with exponential LR decay.

        The learning rate comes from ``self.optimizer.lr``; the scheduler
        multiplies it by 0.95 each scheduler step and is logged as ``expo_lr``.
        """
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.optimizer.lr)
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95),
            'name': 'expo_lr',
        }
        return [optimizer], [lr_scheduler]