"""
PyTorch Lightning module wrapping the anomaly detection model.
"""
import pytorch_lightning as pl
import torch
from .layers import AnomalyDtectionModel
from easydict import EasyDict
from typing import Tuple
from .utils import compute_anomaly_accuracy


class LightningModel(pl.LightningModule):
    """
    Lightning wrapper that trains an ``AnomalyDtectionModel``.

    The model's output is scored with ``BCEWithLogitsLoss`` and with
    ``compute_anomaly_accuracy``. The train, validation and test steps
    share one implementation (``_shared_step``) and log
    ``<stage>_loss`` / ``<stage>_acc`` for ``stage`` in
    ``{"train", "val", "test"}``.
    """

    def __init__(self, config: EasyDict):
        """
        Build the wrapped model, the loss and the optimizer settings.

        Args:
            config: The config. It is passed through unchanged to
                ``AnomalyDtectionModel``. This class also reads
                ``config.lr`` (Adam learning rate), ``config.wd`` (Adam
                weight decay) and ``config.decay_factor`` (the
                ``ReduceLROnPlateau`` factor).
        """
        super().__init__()
        self.model = AnomalyDtectionModel(config=config)
        self.lr = config.lr
        self.wd = config.wd
        self.decay_factor = config.decay_factor
        # BCEWithLogitsLoss applies the sigmoid itself, so the model is
        # expected to return raw logits; targets must be float tensors of
        # the same shape as the logits.
        self.loss_fun = torch.nn.BCEWithLogitsLoss()
        self.acc_fun = compute_anomaly_accuracy

    def forward(self, X):
        """Run the wrapped model on ``X``, passed positionally."""
        return self.model(X)

    def _shared_step(self, batch: Tuple, stage: str) -> torch.Tensor:
        """
        Run the model on one batch, then log and return its loss.

        Args:
            batch: The batch. ``batch[0]`` is fed to the model as keyword
                ``A`` (presumably the graph/adjacency input; confirm against
                the data module), and ``batch[-1]`` holds the targets.
            stage: Metric prefix: ``"train"``, ``"val"`` or ``"test"``.
                Gradients are tracked only when it is ``"train"``.

        Returns: The loss.
        """
        training = stage == "train"
        # Set the wrapped model's mode explicitly on every step, as the
        # original per-step code did. ``train(False)`` is ``eval()``.
        self.model.train(training)
        with torch.set_grad_enabled(training):
            result = self.model(A=batch[0])
            loss = self.loss_fun(result, batch[-1])
            acc = self.acc_fun(result, batch[-1])
        self.log(f"{stage}_loss", loss)
        self.log(f"{stage}_acc", acc)
        return loss

    def training_step(self, batch: Tuple, batch_idx: int) -> torch.Tensor:
        """
        The training step.

        Args:
            batch: The batch.
            batch_idx: The batch index (not used).

        Returns: The loss.
        """
        loss = self._shared_step(batch, "train")
        # The LR scheduler steps once per epoch. The step-level "train_loss"
        # only holds the last batch's value at that point, so log an
        # epoch-mean copy for the scheduler to monitor.
        self.log("train_loss_epoch", loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch: Tuple, batch_idx: int) -> torch.Tensor:
        """
        The validation step.

        Args:
            batch: The batch.
            batch_idx: The batch index (not used).

        Returns: The loss.
        """
        return self._shared_step(batch, "val")

    def test_step(self, batch: Tuple, batch_idx: int) -> torch.Tensor:
        """
        The test step.

        Args:
            batch: The batch.
            batch_idx: The batch index (not used).

        Returns: The loss.
        """
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """
        Build the Adam optimizer and its plateau LR scheduler.

        Returns:
            A dict with ``"optimizer"`` and ``"lr_scheduler"`` keys. The
            scheduler multiplies the LR by ``decay_factor`` when the
            epoch-mean training loss stops decreasing.
        """
        optimizer = torch.optim.Adam(
            params=self.model.parameters(),
            lr=self.lr,
            weight_decay=self.wd,
            eps=1e-07,
        )

        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=self.decay_factor, mode="min"
        )
        lr_scheduler_config = {
            "scheduler": lr_scheduler,
            "interval": "epoch",
            # Epoch-aggregated metric logged in training_step.
            "monitor": "train_loss_epoch",
        }

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
