from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
from lightning import LightningModule

from evaluation.predictive.predictor import Predictor


class PredictorPL(LightningModule):
    """Lightning wrapper that trains a ``Predictor`` with an L1 (mean absolute error) loss.

    Training loss is logged per step under the key ``'loss'``. During
    validation, per-batch losses are collected in ``self.losses`` and their
    mean is logged once per validation epoch under ``'epoch{N}_loss'``, where
    ``N`` is the current epoch index.
    """

    def __init__(self, n_features: int, num_layers: int, hidden_size: int, cutoff: float, lr: float) -> None:
        """Build the wrapped predictor and set up the loss.

        Args:
            n_features: Forwarded to ``Predictor``.
            num_layers: Forwarded to ``Predictor``.
            hidden_size: Forwarded to ``Predictor``.
            cutoff: Forwarded to ``Predictor``.
            lr: Learning rate used by the Adam optimizer.
        """
        super().__init__()

        self.predictor = Predictor(n_features, num_layers, hidden_size, cutoff)

        self.lr = lr
        self.loss_fn = nn.L1Loss()

        # Start with an empty float buffer rather than None: np.append(None, x)
        # would create an object array containing None and break np.mean later
        # if validation_step ever ran before on_validation_epoch_start.
        self.losses = np.array([])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped predictor on ``x`` and return its output."""
        return self.predictor(x)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Compute and log the L1 loss for one training batch.

        Args:
            batch: An ``(inputs, targets)`` pair.

        Returns:
            The loss tensor for this batch.
        """
        x, y = batch
        out = self(x)
        loss = self.loss_fn(out, y)
        self.log('loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> None:
        """Compute the L1 loss for one validation batch and store it.

        Args:
            batch: An ``(inputs, targets)`` pair.
        """
        x, y = batch
        out = self(x)
        loss = self.loss_fn(out, y)
        # Keep only the Python float so no tensor is held between batches.
        self.losses = np.append(self.losses, loss.item())

    def on_validation_epoch_end(self) -> None:
        """Log the mean of the validation losses collected this epoch.

        Nothing is logged when no losses were collected, because the mean of
        an empty array is NaN (and NumPy emits a RuntimeWarning for it).
        """
        if self.losses is None or len(self.losses) == 0:
            return
        loss = np.mean(self.losses)
        # The metric key embeds the epoch index, so each epoch is logged
        # under its own name.
        self.log(f'epoch{self.current_epoch}_loss', loss)

    def on_validation_epoch_start(self) -> None:
        """Reset the validation loss buffer for a new validation epoch."""
        self.losses = np.array([])

    def configure_optimizers(self) -> Dict:
        """Return an Adam optimizer over the predictor's parameters."""
        optimizer = torch.optim.Adam(self.predictor.parameters(), lr=self.lr)
        return {'optimizer': optimizer}
