import copy

import torch
from torch import nn
from torch.distributions import Laplace
from torch.nn.functional import binary_cross_entropy_with_logits as cross_entropy
from lightning import LightningModule
from torchmetrics.classification import BinaryStatScores


class BinaryAccuracy(BinaryStatScores):
    """Plain binary accuracy: ``(TP + TN) / (TP + TN + FP + FN)``.

    State accumulation and ``update`` handling are inherited from torchmetrics'
    ``BinaryStatScores``; only the final reduction is defined here.
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def compute(self) -> torch.Tensor:
        """Return the share of correct predictions seen since the last reset."""
        tp, fp, tn, fn = self._final_state()
        correct = tp + tn
        seen = correct + fp + fn
        return correct / seen


class BinaryBalancedAccuracy(BinaryStatScores):
    """Balanced accuracy: the mean of the true-positive and true-negative rates.

    Unlike plain accuracy this is insensitive to class imbalance. If one of the
    two classes never occurs in the accumulated targets, its rate is ``0 / 0``
    and the result is NaN.
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def compute(self) -> torch.Tensor:
        """Return ``0.5 * (TPR + TNR)`` over all samples seen since the last reset."""
        tp, fp, tn, fn = self._final_state()
        true_positive_rate = tp / (tp + fn)
        true_negative_rate = tn / (tn + fp)
        return 0.5 * (true_positive_rate + true_negative_rate)


class MLP(nn.Sequential):
    """Stack of ``nn.LazyLinear`` layers with an activation after each hidden layer.

    Input width is inferred on the first forward pass (lazy layers).

    Args:
        output_dim: Width of the final linear layer.
        hidden_dim: Width of every hidden linear layer.
        depth: Total number of linear layers; ``depth - 1`` of them are hidden.
            Values ``<= 1`` produce a single linear layer.
        activation: Template activation module. A fresh deep copy is inserted
            after every hidden layer. The default instance is created once, when
            the function is defined, and is itself never placed in any network.
            Copying therefore ensures that layers and separate ``MLP`` objects
            never share one module (which would also tie the parameters of
            parametric activations such as ``nn.PReLU``).
    """

    def __init__(
        self,
        output_dim: int,
        hidden_dim: int = 64,
        depth: int = 4,
        activation: nn.Module = nn.SiLU(),
    ):
        layers = []
        for _ in range(depth - 1):
            layers.append(nn.LazyLinear(hidden_dim))
            # Per-layer copy: never reuse the (shared, definition-time) template.
            layers.append(copy.deepcopy(activation))
        layers.append(nn.LazyLinear(output_dim))
        super().__init__(*layers)


class FairClassifier(LightningModule):
    """Binary classifier whose encoding is pushed towards N(0, I) within each protected group.

    Components:
        * ``encoder``: maps inputs to an ``encoded_dim``-dimensional encoding.
        * ``head_task``: predicts the task label (one logit) from the encoding.
        * ``adv_lin`` / ``adv_mlp``: a linear probe and an MLP that predict the
          protected attribute from the *detached* encoding. They are trained
          alongside the model, but because of the ``detach()`` they only
          measure leakage of the protected attribute and never send gradients
          into the encoder.

    Fairness is enforced by :meth:`fairness_penalty`. The training loss is
    ``(1 - w) * task + w * fairness + adversary losses``, where ``w`` is
    ``fairness_weight`` warmed up linearly over ``warmup_epochs`` epochs.

    Args:
        fairness_weight: Final weight ``w`` of the fairness penalty.
        encoded_dim: Dimension of the encoding.
        hidden_dim: Hidden width of every MLP.
        encoder_depth: Number of linear layers in the encoder.
        head_depth: Number of linear layers in the task head and the MLP adversary.
        normalize_encoding: Project each encoding onto the unit sphere.
        encoding_jitter: Scale of Gaussian noise added to the encoding. The
            noise is applied in every mode, not only during training.
            ``0``/``None`` disables it.
        warmup_epochs: Number of epochs to ramp the fairness weight from 0 up
            to ``fairness_weight``. ``0`` applies the full weight immediately.
        learning_rate: AdamW learning rate for encoder, task head and MLP adversary.
        weight_decay: AdamW weight decay for the same parameter groups.
    """

    def __init__(
        self,
        fairness_weight: float = 0.5,
        encoded_dim: int = 1,
        hidden_dim: int = 64,
        encoder_depth: int = 4,
        head_depth: int = 4,
        normalize_encoding: bool = False,
        encoding_jitter: float = 0.0,
        warmup_epochs: int = 0,
        learning_rate: float = 3e-4,
        weight_decay: float = 1e-4,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Also record the concrete class name (differs for subclasses).
        self.save_hyperparameters({"model": self.__class__.__name__})

        # model components
        self.encoder = MLP(encoded_dim, hidden_dim, encoder_depth)
        self.head_task = MLP(1, hidden_dim, head_depth)
        self.adv_mlp = MLP(1, hidden_dim, head_depth)
        self.adv_lin = nn.LazyLinear(1)

        # metrics: one instance per stage so their accumulated states stay separate
        self.train_accuracy_task = BinaryAccuracy()
        self.val_accuracy_task = BinaryAccuracy()
        self.test_accuracy_task = BinaryAccuracy()

        self.train_accuracy_adv_lin = BinaryBalancedAccuracy()
        self.val_accuracy_adv_lin = BinaryBalancedAccuracy()
        self.test_accuracy_adv_lin = BinaryBalancedAccuracy()

        self.train_accuracy_adv_mlp = BinaryBalancedAccuracy()
        self.val_accuracy_adv_mlp = BinaryBalancedAccuracy()
        self.test_accuracy_adv_mlp = BinaryBalancedAccuracy()

    def configure_optimizers(self):
        """Build a single AdamW optimizer with per-module parameter groups.

        The linear adversary uses its own fixed learning rate (1e-2) and no
        weight decay. All other modules use the ``learning_rate`` and
        ``weight_decay`` hyperparameters.
        """
        lr = self.hparams["learning_rate"]
        wd = self.hparams["weight_decay"]
        optimizer = torch.optim.AdamW(
            [
                dict(params=self.encoder.parameters(), lr=lr, weight_decay=wd),
                dict(params=self.head_task.parameters(), lr=lr, weight_decay=wd),
                dict(params=self.adv_mlp.parameters(), lr=lr, weight_decay=wd),
                dict(params=self.adv_lin.parameters(), lr=1e-2, weight_decay=0.0),
            ],
        )
        return optimizer

    def forward(self, x):
        """Return task logits (last dimension of size 1) for inputs ``x``."""
        return self.head_task(self.encode(x))

    def encode(self, x):
        """Encode ``x``, then optionally add jitter and normalize to unit length."""
        x = self.encoder(x)
        jitter = self.hparams["encoding_jitter"]
        if jitter:  # skips both None and the 0.0 default instead of adding 0 * noise
            x = x + jitter * torch.randn_like(x)
        if self.hparams["normalize_encoding"]:
            # Clamp so an all-zero encoding does not turn into 0 / 0 = NaN.
            x = x / x.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        return x

    def fairness_penalty(self, x_encoded, protected):
        """Characteristic-function penalty between each group and N(0, I).

        Random frequencies ``t`` are drawn from a Laplace(0, 1) distribution
        scaled by ``std``. For each protected group, the empirical
        characteristic function ``mean_j exp(i t . x_j)`` of its encodings is
        compared with the standard-normal characteristic function
        ``exp(-|t|^2 / 2)``. The result is the mean squared modulus of the
        difference over groups and frequencies.
        """
        dim = x_encoded.shape[-1]
        samples = min(1024, 64 * 2**dim)  # TODO make this a hyperparameter?
        std = 3.0  # TODO make this a hyperparameter?
        # Match dtype as well as device so `t @ x.T` also works for e.g. float64.
        t = std * Laplace(0, 1).sample((samples, dim)).to(x_encoded)

        groups = [x_encoded[protected == g] for g in torch.unique(protected)]
        phi_target = torch.exp(-0.5 * torch.sum(t**2, dim=-1))  # standard normal
        # (samples, n_g) phases averaged over the group's samples -> (samples,)
        phi_empirical = torch.stack(
            [torch.mean(torch.exp(1j * (t @ x.T)), dim=-1) for x in groups]
        )
        penalty = torch.abs(phi_empirical - phi_target) ** 2
        return penalty.mean()

    def _fairness_weight(self):
        """Return the fairness weight for the current epoch.

        Epoch ``e`` uses ``fairness_weight * min(1, e / warmup_epochs)``. The
        full weight is therefore reached after exactly ``warmup_epochs``
        epochs, and ``warmup_epochs=0`` applies it from the first epoch.
        """
        warmup = self.hparams["warmup_epochs"]
        ramp = min(1.0, self.current_epoch / warmup) if warmup > 0 else 1.0
        return self.hparams["fairness_weight"] * ramp

    def _forward_heads(self, x):
        """Return ``(x_encoded, logits_task, logits_adv_lin, logits_adv_mlp)`` for ``x``."""
        x_encoded = self.encode(x)
        # Adversaries see a detached encoding: they monitor leakage only.
        x_detached = x_encoded.detach()
        logits_task = self.head_task(x_encoded).squeeze(-1)
        logits_adv_lin = self.adv_lin(x_detached).squeeze(-1)
        logits_adv_mlp = self.adv_mlp(x_detached).squeeze(-1)
        return x_encoded, logits_task, logits_adv_lin, logits_adv_mlp

    def _update_and_log_metrics(
        self, stage, labels, protected, logits_task, logits_adv_lin, logits_adv_mlp
    ):
        """Update the ``{stage}_accuracy_*`` metrics and log them as ``{stage}/accuracy_*``.

        Probabilities rather than logits are passed to torchmetrics. It applies
        a sigmoid only when some prediction lies outside [0, 1], so a batch of
        logits that all happened to fall inside [0, 1] would otherwise be
        treated as probabilities.
        """
        for name, logits, target in (
            ("task", logits_task, labels),
            ("adv_lin", logits_adv_lin, protected),
            ("adv_mlp", logits_adv_mlp, protected),
        ):
            metric = getattr(self, f"{stage}_accuracy_{name}")
            metric.update(torch.sigmoid(logits.detach()), target)
            self.log(f"{stage}/accuracy_{name}", metric, on_epoch=True)

    def training_step(self, batch, batch_idx):
        """Compute and log the combined training loss and the training metrics.

        Assumes ``batch`` unpacks as ``(x, protected, labels)`` with binary
        ``protected``/``labels`` (they are used as BCE targets).
        """
        x, protected, labels = batch
        x_encoded, logits_task, logits_adv_lin, logits_adv_mlp = self._forward_heads(x)

        # loss
        loss_fairness = self.fairness_penalty(x_encoded, protected)
        loss_task = cross_entropy(logits_task, labels.float())
        loss_adv_lin = cross_entropy(logits_adv_lin, protected.float())
        loss_adv_mlp = cross_entropy(logits_adv_mlp, protected.float())
        w = self._fairness_weight()
        loss = (1 - w) * loss_task + w * loss_fairness + loss_adv_lin + loss_adv_mlp

        # logging
        self.log("train/loss_total", loss, on_step=True)
        self.log("train/loss_fairness", loss_fairness, on_step=True)
        self.log("train/loss_task", loss_task, on_step=True)
        self.log("train/loss_adv_lin", loss_adv_lin, on_step=True)
        self.log("train/loss_adv_mlp", loss_adv_mlp, on_step=True)
        self._update_and_log_metrics(
            "train", labels, protected, logits_task, logits_adv_lin, logits_adv_mlp
        )
        return loss

    def _shared_eval_step(self, batch, stage):
        """Run the forward pass and update/log the metrics for ``stage``."""
        x, protected, labels = batch
        _, logits_task, logits_adv_lin, logits_adv_mlp = self._forward_heads(x)
        self._update_and_log_metrics(
            stage, labels, protected, logits_task, logits_adv_lin, logits_adv_mlp
        )

    def validation_step(self, batch, batch_idx):
        """Update and log validation metrics for one batch."""
        self._shared_eval_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Update and log test metrics for one batch."""
        self._shared_eval_step(batch, "test")


class FairClassifierSimple(FairClassifier):
    """FairClassifier variant with a linear task head and a moment-matching penalty."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the parent's MLP task head with a single linear layer.
        self.head_task = nn.LazyLinear(1)

    def fairness_penalty(self, x_encoded, protected):
        """Moment penalty pushing each protected group's encoding towards N(0, I).

        The penalty uses every protected group with at least two samples. For
        each such group it takes the per-dimension mean ``mu`` and unbiased
        variance ``var``. It then sums ``mu**2 + var - 1 - log(var)`` over
        groups and dimensions; this is twice KL(N(mu, var) || N(0, 1)) per
        dimension. Returns 0 if no group has two samples.
        """
        groups = [x_encoded[protected == g] for g in torch.unique(protected)]
        groups = [x for x in groups if len(x) > 1]  # variance needs >= 2 samples
        if not groups:
            return x_encoded.new_zeros(())
        mu = torch.stack([x.mean(0) for x in groups])  # (groups, dim)
        # `var(0)` keeps shape (dim,) even for dim == 1; a squeezed `cov()` would
        # become 0-d there and broadcast against `mu` into a (groups, groups) grid.
        var = torch.stack([x.var(0) for x in groups])  # (groups, dim)
        penalty = mu**2 + var - 1.0 - torch.log(var + 1e-8)
        return penalty.sum()
