import lightning.pytorch as pl
from torch import nn, tensor
import torch

class DiscriminationModel(pl.LightningModule):
    """MLP discriminator that scores pooled encodings as real vs. simulated.

    An encoding is mean-pooled over dimension 1 (ignoring masked positions)
    and passed through a three-layer MLP ending in a sigmoid. Training uses
    binary cross-entropy with target 1 for ``real`` batches and target 0 for
    ``simulated`` batches.
    """

    def __init__(self,
                 model_dim: int,
                 lr: float = 1e-4):
        """Build the MLP, the loss and store the learning rate.

        Args:
            model_dim: Size of the last dimension of the encodings fed to
                the first linear layer.
            lr: Learning rate passed to Adam in ``configure_optimizers``.
        """
        super().__init__()

        # See https://github.com/hihihihiwsf/AST/blob/main/gan_transformer.py
        self.model = nn.Sequential(
            nn.Linear(model_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        self.BCE_loss = nn.BCELoss()
        self.lr = lr

    def _masked_mean_pooling(self,
                             encoding: torch.Tensor,
                             mask: torch.Tensor) -> torch.Tensor:
        """Average ``encoding`` over dimension 1, skipping masked positions.

        Args:
            encoding: Tensor to pool; assumed to be
                ``(batch, seq, model_dim)`` -- TODO confirm against callers.
            mask: Tensor matching the leading dimensions of ``encoding``
                whose truthy entries mark positions to EXCLUDE from the mean.

        Returns:
            ``encoding`` with dimension 1 reduced to the mean of the
            unmasked positions. A row whose positions are all masked yields
            zeros.
        """
        # ``masked_fill`` needs a boolean mask and ``~`` is a bitwise NOT on
        # integer tensors, so normalise the dtype first. This is a no-op for
        # masks that are already boolean.
        mask = mask.to(torch.bool)

        # Zero the excluded positions so they do not contribute to the sum.
        encoding = encoding.masked_fill(mask.unsqueeze(-1), 0.0)
        lengths = (~mask).sum(dim=1, keepdim=True)
        # Clamp so a fully masked row divides by 1 instead of 0.
        pooled = encoding.sum(dim=1) / lengths.clamp(min=1)
        return pooled

    def forward(self, encoding: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Pool ``encoding`` with ``mask`` and return the sigmoid score.

        Args:
            encoding: Encodings to classify (see ``_masked_mean_pooling``).
            mask: Truthy where a position must be ignored by the pooling.

        Returns:
            Sigmoid output of the MLP, one value in (0, 1) per pooled row;
            trained towards 1 for real and 0 for simulated inputs.
        """
        return self.model(self._masked_mean_pooling(encoding, mask))

    def _no_pool_forward(self, encoding: torch.Tensor) -> torch.Tensor:
        """Apply the MLP directly to ``encoding`` without any pooling."""
        return self.model(encoding)

    def training_step(self, batches, *args):
        """Compute the discriminator loss on one real and one simulated batch.

        Args:
            batches: Mapping with keys ``'real'`` and ``'simulated'``; each
                value is a mapping providing ``'encoding'`` and ``'mask'``.
            *args: Ignored.

        Returns:
            Tuple ``(loss, loss_real, loss_sim)`` where ``loss`` is the sum
            of the two binary cross-entropy terms.
        """
        batch_real, batch_sim = batches['real'], batches['simulated']

        # Call the module (not ``.forward``) so registered hooks are honoured.
        real_pred = self(batch_real['encoding'], batch_real['mask'])
        sim_pred = self(batch_sim['encoding'], batch_sim['mask'])

        # Real samples are labelled 1, simulated samples 0.
        loss_real = self.BCE_loss(real_pred, torch.ones_like(real_pred))
        loss_sim = self.BCE_loss(sim_pred, torch.zeros_like(sim_pred))
        loss = loss_real + loss_sim

        # Nothing is logged here; the individual terms are returned instead.
        # NOTE(review): a tuple is not a return value Lightning's automatic
        # optimization accepts, so this step is presumably invoked manually
        # by an enclosing module that also does the logging -- confirm.
        return loss, loss_real, loss_sim

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters with ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
