from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
from lightning import LightningModule
from sklearn.metrics import roc_auc_score

from evaluation.discriminative.discriminator import Discriminator


class DiscriminatorPL(LightningModule):
    """Lightning module that trains a ``Discriminator`` with binary cross-entropy.

    During validation the rounded predictions and the targets of every batch
    are collected, and at the end of the epoch ``abs(AUROC - 0.5)`` is logged
    as the discriminative score.
    """

    def __init__(self, n_features: int, hidden_size: int, num_layers: int, lr: float) -> None:
        """Build the wrapped discriminator and the loss.

        Args:
            n_features: Forwarded as the first argument of ``Discriminator``.
            hidden_size: Forwarded to ``Discriminator``.
            num_layers: Forwarded to ``Discriminator``.
            lr: Learning rate handed to the Adam optimizer.
        """
        super().__init__()

        # The literal ``1`` is presumably the output size (one score per
        # sample) -- TODO confirm against the Discriminator signature.
        self.discriminator = Discriminator(n_features, 1, hidden_size, num_layers)

        self.lr = lr
        # BCELoss expects probabilities in [0, 1]; this assumes the
        # discriminator already applies a sigmoid -- TODO confirm.
        self.loss_fn = nn.BCELoss()

        # Flat arrays holding the targets / rounded predictions of the most
        # recent validation epoch (``None`` until validation has run).
        self.labels, self.preds = None, None

        # Per-batch buffers; they are joined once per epoch instead of
        # re-copying a growing array on every validation step.
        self._label_batches = []
        self._pred_batches = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped discriminator's output for ``x``."""
        return self.discriminator(x)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Compute and log the binary cross-entropy loss for one batch.

        Args:
            batch: ``(inputs, targets)`` pair.

        Returns:
            The loss tensor used by Lightning for the backward pass.
        """
        x, y = batch
        o = self(x)
        # NOTE(review): BCELoss needs ``o`` and ``y`` to have the same shape
        # and a floating dtype; nothing here enforces that.
        loss = self.loss_fn(o, y)
        self.log('loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> None:
        """Store the targets and rounded predictions of one validation batch.

        Args:
            batch: ``(inputs, targets)`` pair.
        """
        x, y = batch
        # NOTE(review): rounding turns the scores into hard 0/1 decisions, so
        # the AUROC computed later equals balanced accuracy rather than a
        # ranking-based AUROC. Kept as in the original -- confirm intent.
        o = self(x).round()
        # Flatten per batch so batches of different shapes can be joined.
        self._label_batches.append(y.detach().cpu().numpy().ravel())
        self._pred_batches.append(o.detach().cpu().numpy().ravel())

    def on_validation_epoch_end(self) -> None:
        """Log ``abs(AUROC - 0.5)`` over everything collected this epoch.

        Logging is skipped (with a warning) when fewer than two classes were
        seen, because AUROC is undefined in that case.
        """
        # Leading empty float64 array: keeps the call valid when no batch was
        # seen and yields the same flat float64 result ``np.append`` produced.
        self.labels = np.concatenate([np.array([]), *self._label_batches])
        self.preds = np.concatenate([np.array([]), *self._pred_batches])
        # Release the per-batch buffers; the joined arrays are kept instead.
        self._label_batches, self._pred_batches = [], []

        if np.unique(self.labels).size < 2:
            # roc_auc_score raises or returns NaN (depending on the
            # scikit-learn version) for single-class targets, e.g. during a
            # short sanity check over class-ordered data.
            import warnings

            warnings.warn(
                'Validation targets contain fewer than two classes; '
                'skipping the discriminative score for this epoch.'
            )
            return

        auroc = roc_auc_score(self.labels, self.preds)
        discriminative_score = abs(auroc - .5)
        self.log(f'epoch{self.current_epoch}_discriminative_score', discriminative_score)

    def on_validation_epoch_start(self) -> None:
        """Reset everything collected during the previous validation epoch."""
        self.labels = np.array([])
        self.preds = np.array([])
        self._label_batches = []
        self._pred_batches = []

    def configure_optimizers(self) -> Dict:
        """Return an Adam optimizer over the discriminator's parameters."""
        optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr)
        return {'optimizer': optimizer}
