from abc import ABC, abstractmethod
import torch
import torchmetrics
import pytorch_lightning as pl

from fairret.metrics import FairMetricCollection
from fairret.loss.base import FairnessLoss


class Model(pl.LightningModule, ABC):
    """Abstract Lightning module for binary classification with an optional fairness loss.

    Subclasses set the class attribute ``name`` and implement ``forward(feat, sens)``, which
    returns logits. The training objective is the BCE loss (``self.loss_fn``) plus, from epoch
    ``epoch_start_fairret`` onward, ``fairret_strength`` times the ``fairret`` loss (optionally
    rescaled by ``balancing_weight``). AUROC, fairness metrics and (optionally) the fairret's
    internal metrics are accumulated per stage and logged at the end of every epoch.
    """

    # Identifier of the concrete model class; every subclass must override it.
    name = None

    def __init_subclass__(cls, **kwargs):
        # Enforce at class-creation time that every subclass declares a `name`.
        super().__init_subclass__(**kwargs)
        if cls.name is None:
            raise ValueError(f"Model class {cls} must set a 'name'.")

    def __init__(self,
                 feat_dim=None,
                 sens_dim=None,
                 simple_sens_cols=None,
                 lr=1e-3,
                 fairret: FairnessLoss = None,
                 fairret_strength=1.,
                 epoch_start_fairret=0,
                 balancing=False,
                 use_simple_sens=False,
                 log_simple_sens=True,
                 log_fairret_internal=False,
                 log_unnormalized=False,
                 **kwargs):
        """Configure the model, its losses and its per-stage metric collections.

        Args:
            feat_dim: Stored as ``self.feat_dim`` for subclasses (presumably the number of
                input features — confirm in subclasses).
            sens_dim: Forwarded to ``FairMetricCollection`` (presumably the number of
                sensitive-attribute columns).
            simple_sens_cols: Column selector into ``sens`` whose selected columns must one-hot
                encode the demographic group. Used by the fair metrics (if
                ``log_simple_sens``) and as fairret input (if ``use_simple_sens``).
            lr: Learning rate of the Adam optimizer.
            fairret: Fairness loss, called as ``fairret(logit, feat, sens, label, metrics=...)``.
                ``None`` disables the fairness term.
            fairret_strength: Weight of the fairness term; ``None`` is treated as ``0.``.
            epoch_start_fairret: First epoch in which the fairness term is added to the loss.
            balancing: If True, at the end of epoch ``epoch_start_fairret`` the fairness loss
                is rescaled by the ratio of the logged BCE and fairret losses of that epoch.
            use_simple_sens: Feed only the ``simple_sens_cols`` columns of ``sens`` to the fairret.
            log_simple_sens: If False, ``simple_sens_cols`` are not passed to the fair metrics.
            log_fairret_internal: Create and log ``fairret.internal_metrics`` for every stage.
            log_unnormalized: Passed to the fair metrics as ``normalized=not log_unnormalized``.
            **kwargs: Must be empty; present only to report unexpected options.

        Raises:
            ValueError: If unused keyword arguments are given, or if an 'ffb' fairret is
                configured without ``use_simple_sens``.
        """
        super().__init__()

        # Fail fast on misspelled/unsupported options before building any metrics.
        if len(kwargs) > 0:
            raise ValueError(f"The following kwargs were not used: {kwargs}")

        self.feat_dim = feat_dim
        self.sens_dim = sens_dim
        # Columns gathered for the fairret when `use_simple_sens` is set. Kept apart from
        # `self.simple_sens_cols` so that disabling their logging does not break the fairret.
        self._fairret_sens_cols = simple_sens_cols
        if not log_simple_sens:
            simple_sens_cols = None  # Avoid using simple sens columns in metrics if we aren't logging them
        self.simple_sens_cols = simple_sens_cols
        self.lr = lr
        self.fairret = fairret
        if fairret_strength is None:
            fairret_strength = 0.
        self.fairret_strength = fairret_strength
        self.epoch_start_fairret = epoch_start_fairret
        self.balancing = balancing
        self.use_simple_sens = use_simple_sens
        self.log_fairret_internal = log_fairret_internal
        self.log_unnormalized = log_unnormalized

        # Called as loss_fn(logit, label, idx) in `step`.
        self.loss_fn = bce_loss_fn

        self.train_metrics = torchmetrics.MetricCollection({
            'auroc': torchmetrics.classification.BinaryAUROC()
        },
            prefix="train/",
            compute_groups=False)
        self.train_fair_metrics = FairMetricCollection(
            sens_dim=self.sens_dim,
            simple_sens_cols=self.simple_sens_cols,
            normalized=not log_unnormalized,
            prefix="train/")
        # Each stage gets its own copy so that metric states never mix between stages.
        self.val_metrics = self.train_metrics.clone(prefix="val/")
        self.val_fair_metrics = self.train_fair_metrics.clone(prefix="val/")
        self.test_metrics = self.train_metrics.clone(prefix="test/")
        self.test_fair_metrics = self.train_fair_metrics.clone(prefix="test/")
        if fairret is not None and log_fairret_internal:
            self.train_fairret_metrics = fairret.internal_metrics(prefix="train/")
            self.val_fairret_metrics = fairret.internal_metrics(prefix="val/")
            self.test_fairret_metrics = fairret.internal_metrics(prefix="test/")
        else:
            self.train_fairret_metrics = None
            self.val_fairret_metrics = None
            self.test_fairret_metrics = None

        if fairret is not None and fairret.name == 'ffb' and not self.use_simple_sens:
            raise ValueError("The 'ffb' fairret requires use_simple_sens=True.")

        # Float multiplier for the fairret loss; set in `on_train_epoch_end` when `balancing`.
        self.balancing_weight = None

    def _fairret_enabled(self):
        """Return True if a fairness loss is configured and has a nonzero strength."""
        return self.fairret is not None and self.fairret_strength != 0.

    def _reset_stage_metrics(self, stage):
        """Reset every metric collection that accumulates state for `stage` ('train'/'val'/'test')."""
        for kind in ("metrics", "fair_metrics", "fairret_metrics"):
            collection = getattr(self, f"{stage}_{kind}")
            if collection is not None:  # fairret internal metrics are optional
                collection.reset()

    def setup(self, stage):
        """Lightning hook: clear accumulated metric state (and the balancing weight on 'fit')."""
        if stage == 'fit':
            # Fitting also runs validation, so both stages start from a clean state.
            self._reset_stage_metrics("train")
            self._reset_stage_metrics("val")
            self.balancing_weight = None
        elif stage == 'validate':
            self._reset_stage_metrics("val")
        elif stage == 'test':
            self._reset_stage_metrics("test")

    @abstractmethod
    def forward(self, feat, sens):
        """Return the logits for a batch of features `feat` and sensitive attributes `sens`."""
        pass

    def step(self, batch, _batch_idx, stage):
        """Shared train/val/test step: compute and log the loss and update the stage's metrics.

        Args:
            batch: Tuple ``(feat, sens, label, idx)``; ``idx`` is forwarded to ``self.loss_fn``.
            _batch_idx: Unused.
            stage: ``"train"``, ``"val"`` or ``"test"``; selects the metric collections and
                the log-key prefix.

        Returns:
            The total loss: BCE plus the weighted fairret term when it is active.

        Raises:
            ValueError: If the logits or the fairret loss contain NaN values.
        """
        feat, sens, label, idx = batch
        logit = self(feat, sens)
        # A single NaN logit already makes the BCE loss NaN, so reject any NaN, not only all-NaN.
        if logit.isnan().any():
            raise ValueError("NaN values in logits.")

        bce_loss = self.loss_fn(logit, label.float(), idx)
        self.log(f"{stage}/bce_loss", bce_loss.detach(), on_step=False, on_epoch=True, logger=True)

        if self._fairret_enabled() and self.current_epoch >= self.epoch_start_fairret:
            fairret_internal_metrics = getattr(self, f"{stage}_fairret_metrics")
            if self.use_simple_sens:
                sens_ = self._gather_simple_sens(sens)
            else:
                sens_ = sens
            fairret_loss = self.fairret(logit, feat, sens_, label,
                                        metrics=fairret_internal_metrics)

            if fairret_loss.isnan():
                raise ValueError("NaN values in fairret output.")

            # Logged before balancing so the balancing weight is computed from the raw loss.
            self.log(f"{stage}/fairret_loss", fairret_loss.detach(), on_step=False, on_epoch=True, logger=True)

            if self.balancing_weight is not None:
                fairret_loss = self.balancing_weight * fairret_loss
            loss = bce_loss + self.fairret_strength * fairret_loss
        else:
            loss = bce_loss
        self.log(f"{stage}/loss", loss.detach(), on_step=False, on_epoch=True, prog_bar=True, logger=True)

        # Note: TorchMetrics are maintained separately and only aggregated at the end of the epoch.
        pred = torch.sigmoid(logit).detach()
        metrics = getattr(self, f"{stage}_metrics")
        metrics.update(pred, label)
        fair_metrics = getattr(self, f"{stage}_fair_metrics")
        # Fair metrics always receive the full `sens`, regardless of `use_simple_sens`.
        fair_metrics.update(pred, feat, sens, label)
        return loss

    def on_epoch_end(self, stage):
        """Compute, log and reset all metric collections of `stage`."""
        # if not self.trainer.sanity_checking:
        metrics = getattr(self, f"{stage}_metrics")
        self.log_dict(metrics.compute(), on_step=False, on_epoch=True, logger=True)
        metrics.reset()

        fair_metrics = getattr(self, f"{stage}_fair_metrics")
        self.log_dict(fair_metrics.compute(), on_step=False, on_epoch=True, logger=True)
        fair_metrics.reset()

        fairret_internal_metrics = getattr(self, f"{stage}_fairret_metrics")
        if fairret_internal_metrics is not None:
            self.log_dict(fairret_internal_metrics.compute(), on_step=False, on_epoch=True, logger=True)
            fairret_internal_metrics.reset()

    def training_step(self, *args):
        return self.step(*args, stage="train")

    def on_train_epoch_end(self) -> None:
        """Log training metrics and, if `balancing`, compute the fairret balancing weight."""
        self.on_epoch_end("train")

        if (self.balancing and self._fairret_enabled()
                and self.trainer.current_epoch == self.epoch_start_fairret):
            # TODO: could already compute the balancing weight before the epoch_start_fairret, but this is unintuitive

            # Ratio of the epoch-level BCE and (unbalanced) fairret losses logged in `step`.
            bce_loss = self.trainer.logged_metrics['train/bce_loss']
            fairret_loss = self.trainer.logged_metrics['train/fairret_loss']
            self.balancing_weight = (bce_loss / fairret_loss).item()

    def validation_step(self, *args):
        return self.step(*args, stage="val")

    def on_validation_epoch_end(self) -> None:
        self.on_epoch_end("val")

    def test_step(self, *args):
        return self.step(*args, stage="test")

    def on_test_epoch_end(self) -> None:
        self.on_epoch_end("test")

    def configure_optimizers(self):
        """Optimize all parameters with Adam at learning rate `self.lr`."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def on_save_checkpoint(self, checkpoint):
        """Persist `balancing_weight`, which is a plain float and not part of the state dict."""
        super().on_save_checkpoint(checkpoint)
        checkpoint['balancing_weight'] = self.balancing_weight

    def on_load_checkpoint(self, checkpoint):
        """Reset training state via setup('fit'), then restore the saved balancing weight."""
        self.setup("fit")
        super().on_load_checkpoint(checkpoint)
        # `setup("fit")` cleared the weight; checkpoints without the key fall back to None.
        self.balancing_weight = checkpoint.get('balancing_weight')

    def _gather_simple_sens(self, sens):
        """Select the simple sensitive columns of `sens` and check they are one-hot encoded.

        Raises:
            ValueError: If no simple sensitive columns were configured, or if a row of the
                selected columns does not sum to exactly 1.
        """
        cols = self._fairret_sens_cols
        if cols is None or len(cols) == 0:
            raise ValueError("use_simple_sens requires simple_sens_cols to be set.")
        sens = sens[:, cols]

        if (sens.sum(dim=1) != 1).any():
            raise ValueError("Simple sensitive features must be one-hot encodings of demographic group.")
        return sens


def bce_loss_fn(logit, label, *_args):
    """Return the mean binary cross-entropy between raw `logit` values and `label`.

    `label` is cast to float before use. Any extra positional arguments (such as sample
    indices) are accepted and ignored, so this function fits a ``loss_fn(logit, label, idx)``
    call signature.
    """
    bce_with_logits = torch.nn.functional.binary_cross_entropy_with_logits
    target = label.float()
    return bce_with_logits(logit, target)
