import torch
from sklearn.metrics import roc_auc_score

from .ff_net import FeedForwardNet


class AdvDebiasing(FeedForwardNet):
    """
    Multi-layer perceptron with adversarial training for fairness.

    Modified from https://github.com/ahxt/fair_fairness_benchmark/
    which in turn modified it from https://github.com/hanzhaoml/ICLR2020-CFair/blob/master/models.py
    ("Multi-layer perceptron with adversarial training for fairness").

    An adversary network reads the last hidden layer of the predictor through a gradient reversal layer
    and is trained (with cross-entropy) to predict the sensitive attribute. Because the gradient flowing
    back into the predictor is negated, the predictor's hidden representation is pushed to make the
    adversary's task harder.

    NOTE: this could be implemented as a fairret, but it would require that we pass hidden layer activations
    to the fairret, which would require more changes to the rest of the pipeline.
    Especially if we would also use LAFTR as a fairret.
    """
    name = 'adv_debias'

    def __init__(self,
                 adv_dim=None,
                 label_given=False,
                 **kwargs):
        """
        Args:
            adv_dim: hidden layer sizes for the adversary network, passed to ``build_net``. ``None`` is
                replaced by an empty list (presumably an adversary without hidden layers -- depends on
                ``build_net``).
            label_given: if True, the true label is appended to the adversary's input, so the adversary
                predicts the sensitive attribute conditioned on the label.
            **kwargs: forwarded to ``FeedForwardNet``. Any ``fairret`` entry is overridden with ``None``
                because the adversarial loss takes the place of the fairret term.
        """
        kwargs['fairret'] = None
        super().__init__(**kwargs)

        # Avoid a mutable default argument; an empty list is the "no hidden layers" configuration.
        if adv_dim is None:
            adv_dim = []
        self.adv_dim = adv_dim
        self.label_given = label_given

        # Built lazily in ``setup`` once the predictor's dimensions are known.
        self.adv_net = None

    def setup(self, stage):
        """
        Build the adversary network for the "fit" stage.

        For every other stage this returns immediately, so ``adv_net`` is only created during fitting.

        Raises:
            ValueError: if the predictor has no hidden layer for the adversary to attach to.
        """
        if stage != "fit":
            return

        super().setup(stage)

        if len(self.dim) == 0:
            raise ValueError("Adversarial debiasing only works for networks with at least one hidden layer!")

        # The adversary reads the predictor's last hidden layer (plus one column if the label is appended).
        adv_input_dim = self.dim[-1]
        if self.label_given:
            adv_input_dim += 1
        self.adv_net = self.build_net(adv_input_dim, self.adv_dim, output_dim=len(self.simple_sens_cols))

    def forward(self, feat, _sens, label=None):
        """
        Run the predictor and the adversary.

        Args:
            feat: input features.
            _sens: unused; kept for interface compatibility with the base model.
            label: true labels; required when ``label_given`` is True, ignored otherwise.

        Returns:
            ``(logit, adv_logit)``: the predictor's logit (last output column dropped via ``[..., 0]``) and
            the adversary's logits, one per sensitive column.

        Raises:
            ValueError: if ``label_given`` is True but no label is passed.
        """
        *hidden_layers, output_layer = self.net
        for layer in hidden_layers:
            feat = layer(feat)
        logit = output_layer(feat)[..., 0]

        # Identity in the forward pass; the gradient back into the predictor is negated.
        adv_feat = grad_reverse(feat)

        if self.label_given:
            if label is None:
                raise ValueError("label must be provided when label_given=True.")
            # Cast explicitly: labels may be integer/bool while the hidden features are floating point.
            label_col = label.unsqueeze(-1).to(adv_feat.dtype)
            adv_feat = torch.cat([adv_feat, label_col], dim=-1)

        *adv_hidden_layers, adv_output_layer = self.adv_net
        for layer in adv_hidden_layers:
            adv_feat = layer(adv_feat)
        adv_logit = adv_output_layer(adv_feat)

        return logit, adv_logit

    def step(self, batch, _batch_idx, stage):
        """
        Compute, log and return the loss for one batch.

        The loss is the prediction loss (``loss_fn``), plus ``fairret_strength`` times the adversary's
        cross-entropy once ``current_epoch >= epoch_start_fairret`` (and the strength is non-zero).
        Thanks to the gradient reversal layer, minimizing this sum trains the adversary to predict the
        sensitive attribute while training the predictor to hide it.

        Raises:
            ValueError: if the predictor's or the adversary's logits are all NaN, or the adversarial loss is NaN.
        """
        # Mostly copy-pasted from base Model...

        feat, sens, label, idx = batch
        logit, adv_logit = self(feat, sens, label=(label if self.label_given else None))
        if logit.isnan().all():
            raise ValueError("NaN values in logits.")
        if adv_logit.isnan().all():
            raise ValueError("NaN values in adversarial logits.")

        bce_loss = self.loss_fn(logit, label.float(), idx)
        self.log(f"{stage}/bce_loss", bce_loss.detach(), on_step=False, on_epoch=True, logger=True)

        if self.fairret_strength != 0. and self.current_epoch >= self.epoch_start_fairret:
            # The sensitive columns are collapsed to a class index via argmax.
            # NOTE(review): with a single sensitive column the adversary has one output, so the
            # cross-entropy is constantly zero and the argmax is always 0 -- confirm sensitive attributes
            # always have at least two columns here.
            simple_sens = self._gather_simple_sens(sens)
            simple_sens = torch.argmax(simple_sens, dim=1).long()
            fairret_loss = torch.nn.functional.cross_entropy(adv_logit, simple_sens)

            if fairret_loss.isnan():
                raise ValueError("NaN values in adversarial loss.")

            acc = torch.mean((torch.argmax(adv_logit, dim=1) == simple_sens).float())
            self.log(f"{stage}/adv_acc", acc, on_step=False, on_epoch=True, logger=True)

            self.log(f"{stage}/fairret_loss", fairret_loss.detach(), on_step=False, on_epoch=True, logger=True)

            if self.balancing_weight is not None:
                fairret_loss = self.balancing_weight * fairret_loss
            loss = bce_loss + self.fairret_strength * fairret_loss
        else:
            loss = bce_loss
        self.log(f"{stage}/loss", loss.detach(), on_step=False, on_epoch=True, prog_bar=True, logger=True)

        # Note: TorchMetrics are maintained separately and only aggregated at the end of the epoch.
        pred = torch.sigmoid(logit).detach()
        metrics = getattr(self, f"{stage}_metrics")
        metrics.update(pred, label)
        fair_metrics = getattr(self, f"{stage}_fair_metrics")
        fair_metrics.update(pred, feat, sens, label)
        return loss


class GradReverse(torch.autograd.Function):
    """
    Gradient reversal layer, borrowed from https://github.com/hanzhaoml/ICLR2020-CFair/blob/master/models.py

    Originally meant for domain-adaptation networks. In the forward pass it acts as the identity; in the
    backward pass it flips the sign of the incoming gradient.
    """

    @staticmethod
    def forward(ctx, x):
        # A view (rather than ``x`` itself) gives autograd a distinct output tensor to attach this
        # Function's backward to.
        identity_view = x.view_as(x)
        return identity_view

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse the gradient: d(output)/d(x) is treated as -1.
        flipped_grad = -grad_output
        return flipped_grad


def grad_reverse(x):
    """Return ``x`` unchanged in the forward pass, with its gradient negated in the backward pass."""
    reversed_x = GradReverse.apply(x)
    return reversed_x
