import torch
from torch.nn import functional as F
import pytorch_lightning as pl


class PLModel(pl.LightningModule):
    """Lightning wrapper for masked classification on a feature-only model.

    The wrapped ``model`` is called with ``batch.x`` and its output is
    indexed with ``batch.train_mask`` / ``batch.val_mask`` /
    ``batch.test_mask``; targets are read from ``batch.y``.
    (Presumably ``batch`` is a graph-style data object; confirm against the
    data module.)

    Per-epoch loss and accuracy are written through
    ``self.logger.experiment.add_scalar`` under ``Loss/<Split>`` and
    ``Accuracy/<Split>``.
    """

    def __init__(self, model, **kwargs):
        """Store the wrapped model and optimisation settings.

        Args:
            model: Module called as ``model(batch.x)`` in every step.
            **kwargs: Must contain ``learning_rate`` and
                ``min_learning_rate``.

        Raises:
            KeyError: If either required keyword argument is missing.
        """
        super().__init__()
        self.save_hyperparameters()
        self.lr = kwargs['learning_rate']
        self.min_lr = kwargs['min_learning_rate']
        self.model = model
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=self.min_lr)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1
            }
        }

    def _masked_step(self, batch, mask, stage):
        """Compute loss/accuracy on the entries of ``batch`` selected by ``mask``.

        Args:
            batch: Object exposing ``x`` and ``y``.
            mask: Index/mask applied to both the model output and ``batch.y``.
            stage: Prefix (``"train"``, ``"val"`` or ``"test"``) used for the
                keys of the nested ``"log"`` dictionary.

        Returns:
            dict with ``loss`` and ``acc`` (tensors), ``log`` (stage-prefixed
            copy of both), ``correct`` (int) and ``total`` (int).
        """
        logits = self.model(batch.x)[mask]
        targets = batch.y[mask]

        loss = self.criterion(logits, targets)

        # Compute the predictions once and derive both the count and the ratio.
        hits = logits.argmax(dim=1).eq(targets).sum()
        total = len(targets)
        acc = hits / total

        return {
            "loss": loss,
            "acc": acc,
            "log": {f"{stage}_loss": loss, f"{stage}_acc": acc},
            "correct": hits.item(),
            "total": total,
        }

    def _write_epoch_scalars(self, outputs, split):
        """Write the epoch's mean loss and overall accuracy for ``split``.

        Args:
            outputs: List of dictionaries returned by the step method.
            split: Tag suffix, e.g. ``"Train"``, ``"Val"`` or ``"Test"``.
        """
        if self.logger is None:
            # Nothing to write to when the Trainer runs without a logger.
            return

        avg_loss = torch.stack([out['loss'] for out in outputs]).mean()
        correct = sum(out['correct'] for out in outputs)
        total = sum(out['total'] for out in outputs)
        # An empty mask leaves the accuracy undefined; report NaN instead of
        # raising ZeroDivisionError in the middle of a run.
        accuracy = correct / total if total else float("nan")

        # Assumes the logger's experiment exposes ``add_scalar``
        # (e.g. a TensorBoard writer) - TODO confirm for other loggers.
        writer = self.logger.experiment
        writer.add_scalar(f"Loss/{split}", avg_loss, self.current_epoch)
        writer.add_scalar(f"Accuracy/{split}", accuracy, self.current_epoch)

    def training_step(self, batch, batch_idx):
        """Run one training step on the entries selected by ``batch.train_mask``."""
        result = self._masked_step(batch, batch.train_mask, "train")

        # NOTE: train accuracy is logged as a percentage, whereas val/test
        # accuracy is logged as a fraction; kept as-is so existing dashboards
        # and monitors keep their scale.
        self.log("train_acc", result["acc"]*100, prog_bar=True)

        return result

    def training_epoch_end(self, outputs):
        """Write epoch-level training loss and accuracy."""
        self._write_epoch_scalars(outputs, "Train")

    def validation_step(self, batch, batch_idx, loader_idx=2):
        """Run one validation step on the entries selected by ``batch.val_mask``."""
        result = self._masked_step(batch, batch.val_mask, "val")

        self.log("val_acc", result["acc"], prog_bar=True)
        # ``val_loss`` is the quantity monitored by the LR scheduler.
        self.log("val_loss", result["loss"], prog_bar=False)

        return result

    def validation_epoch_end(self, outputs):
        """Write epoch-level validation loss and accuracy."""
        self._write_epoch_scalars(outputs, "Val")

    def test_step(self, batch, batch_idx, loader_idx=2):
        """Run one test step on the entries selected by ``batch.test_mask``."""
        result = self._masked_step(batch, batch.test_mask, "test")

        # Number of selected entries; used to weight the epoch aggregation.
        n_selected = result["total"]
        self.log('test_loss', result["loss"], prog_bar=False, on_step=False, on_epoch=True, batch_size=n_selected)
        self.log('test_acc', result["acc"], prog_bar=True, on_step=False, on_epoch=True, batch_size=n_selected)

        return result

    def test_epoch_end(self, outputs):
        """Write epoch-level test loss and accuracy.

        Accuracy is written under ``Accuracy/Test`` so that it no longer
        lands on the validation curve.
        """
        self._write_epoch_scalars(outputs, "Test")



class PLEFATModel(pl.LightningModule):
    """Lightning wrapper for masked classification on a whole-batch model.

    The wrapped ``model`` is called with the entire ``batch`` object in the
    step methods, and its output is indexed with ``batch.train_mask`` /
    ``batch.val_mask`` / ``batch.test_mask``; targets are read from
    ``batch.y``.

    Per-epoch loss and accuracy are written through
    ``self.logger.experiment.add_scalar`` under ``Loss/<Split>`` and
    ``Accuracy/<Split>``.
    """

    def __init__(self, model, **kwargs):
        """Store the wrapped model and optimisation settings.

        Args:
            model: Module called as ``model(batch)`` in every step.
            **kwargs: Must contain ``learning_rate`` and
                ``min_learning_rate``.

        Raises:
            KeyError: If either required keyword argument is missing.
        """
        super().__init__()
        self.save_hyperparameters()
        self.lr = kwargs['learning_rate']
        self.min_lr = kwargs['min_learning_rate']
        self.model = model
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=self.min_lr)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1
            }
        }

    def _masked_step(self, batch, mask, stage):
        """Compute loss/accuracy on the entries of ``batch`` selected by ``mask``.

        Args:
            batch: Object passed whole to the model; must also expose ``y``.
            mask: Index/mask applied to both the model output and ``batch.y``.
            stage: Prefix (``"train"``, ``"val"`` or ``"test"``) used for the
                keys of the nested ``"log"`` dictionary.

        Returns:
            dict with ``loss`` and ``acc`` (tensors), ``log`` (stage-prefixed
            copy of both), ``correct`` (int) and ``total`` (int).
        """
        logits = self.model(batch)[mask]
        targets = batch.y[mask]

        loss = self.criterion(logits, targets)

        # Compute the predictions once and derive both the count and the ratio.
        hits = logits.argmax(dim=1).eq(targets).sum()
        total = len(targets)
        acc = hits / total

        return {
            "loss": loss,
            "acc": acc,
            "log": {f"{stage}_loss": loss, f"{stage}_acc": acc},
            "correct": hits.item(),
            "total": total,
        }

    def _write_epoch_scalars(self, outputs, split):
        """Write the epoch's mean loss and overall accuracy for ``split``.

        Args:
            outputs: List of dictionaries returned by the step method.
            split: Tag suffix, e.g. ``"Train"``, ``"Val"`` or ``"Test"``.
        """
        if self.logger is None:
            # Nothing to write to when the Trainer runs without a logger.
            return

        avg_loss = torch.stack([out['loss'] for out in outputs]).mean()
        correct = sum(out['correct'] for out in outputs)
        total = sum(out['total'] for out in outputs)
        # An empty mask leaves the accuracy undefined; report NaN instead of
        # raising ZeroDivisionError in the middle of a run.
        accuracy = correct / total if total else float("nan")

        # Assumes the logger's experiment exposes ``add_scalar``
        # (e.g. a TensorBoard writer) - TODO confirm for other loggers.
        writer = self.logger.experiment
        writer.add_scalar(f"Loss/{split}", avg_loss, self.current_epoch)
        writer.add_scalar(f"Accuracy/{split}", accuracy, self.current_epoch)

    def training_step(self, batch, batch_idx):
        """Run one training step on the entries selected by ``batch.train_mask``."""
        result = self._masked_step(batch, batch.train_mask, "train")

        # NOTE: train accuracy is logged as a percentage, whereas val/test
        # accuracy is logged as a fraction; kept as-is so existing dashboards
        # and monitors keep their scale.
        self.log("train_acc", result["acc"]*100, prog_bar=True)

        return result

    def training_epoch_end(self, outputs):
        """Write epoch-level training loss and accuracy."""
        self._write_epoch_scalars(outputs, "Train")

    def validation_step(self, batch, batch_idx, loader_idx=2):
        """Run one validation step on the entries selected by ``batch.val_mask``."""
        result = self._masked_step(batch, batch.val_mask, "val")

        self.log("val_acc", result["acc"], prog_bar=True)
        # ``val_loss`` is the quantity monitored by the LR scheduler.
        self.log("val_loss", result["loss"], prog_bar=False)

        return result

    def validation_epoch_end(self, outputs):
        """Write epoch-level validation loss and accuracy."""
        self._write_epoch_scalars(outputs, "Val")

    def test_step(self, batch, batch_idx, loader_idx=2):
        """Run one test step on the entries selected by ``batch.test_mask``."""
        result = self._masked_step(batch, batch.test_mask, "test")

        # Number of selected entries; used to weight the epoch aggregation.
        n_selected = result["total"]
        self.log('test_loss', result["loss"], prog_bar=False, on_step=False, on_epoch=True, batch_size=n_selected)
        self.log('test_acc', result["acc"], prog_bar=True, on_step=False, on_epoch=True, batch_size=n_selected)

        return result

    def test_epoch_end(self, outputs):
        """Write epoch-level test loss and accuracy.

        Accuracy is written under ``Accuracy/Test`` so that it no longer
        lands on the validation curve.
        """
        self._write_epoch_scalars(outputs, "Test")

