import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torchmetrics import Accuracy
import warnings
from utils import FairnessMetricDDP, FairnessMetricDP, FairnessMetricEO
from rn import feature_neutralization

# Silence every warning category for the rest of the process. This is a
# process-wide, import-time side effect: it also hides warnings emitted later
# by torch / pytorch_lightning / torchmetrics and by any code that imports
# this module. NOTE(review): consider narrowing by category or module.
warnings.simplefilter("ignore")


class DBMModule(pl.LightningModule):
    """Train a classifier head on neutralized representations of a fixed model.

    ``pretrain_model`` maps an input batch to a ``(pred, rep)`` pair and is
    never optimized here. During training, ``rep`` is passed through
    ``feature_neutralization`` before reaching ``target_model`` (the
    classifier). During validation/testing the classifier is applied to the
    raw ``rep`` (see ``forward``). Only the classifier's parameters are handed
    to the optimizer.

    Batches are ``(x, s, y)`` tuples. ``s`` splits samples into the
    ``s == 0`` / ``s == 1`` groups used by the weighted training loss and the
    fairness metrics (presumably a binary sensitive attribute -- confirm).
    ``y`` is the class target consumed by ``CrossEntropyLoss``.
    """

    def __init__(self,
                 pretrain_model: nn.Module,
                 target_model: nn.Module,
                 hid_size: int,
                 learning_rate: float = 1e-3,
                 regularization_weight: float = 1e-1):
        """
        Args:
            pretrain_model: Module returning ``(pred, rep)``; run without
                gradients during training and excluded from the optimizer.
            target_model: Classifier applied to ``rep``; the only parameters
                optimized by ``configure_optimizers``.
            hid_size: Forwarded to ``feature_neutralization`` (presumably the
                width of ``rep`` -- confirm against ``rn``).
            learning_rate: Adam learning rate.
            regularization_weight: Weight ``w`` of the ``s == 0`` group's
                cross-entropy in the training loss; the ``s == 1`` group's
                cross-entropy is weighted by ``1 - w``.
        """
        super().__init__()
        # NOTE(review): this also records both nn.Module arguments as
        # hyperparameters. Lightning suggests save_hyperparameters(ignore=[...])
        # for modules, but that would change what load_from_checkpoint needs,
        # so it is left as is.
        self.save_hyperparameters()
        self.pretrain_model = pretrain_model
        self.classifier = target_model
        self.softmax = nn.Softmax(dim=1)
        self.hid_size = hid_size
        self.criterion = nn.CrossEntropyLoss()
        # One accuracy metric shared by train/val/test; only its per-batch
        # forward() value is logged, never compute().
        self.accuracy = Accuracy(task="binary")
        self.regularization_weight = regularization_weight
        self.learning_rate = learning_rate

        # Fairness metrics accumulated over test batches and reported in
        # on_test_epoch_end.
        self.fairness_metrics = nn.ModuleDict({
            'DDP': FairnessMetricDDP(),
            'EO': FairnessMetricEO()
        })

    def forward(self, x):
        """Classify ``x`` from the pretrained model's raw (un-neutralized) rep."""
        _, rep = self.pretrain_model(x)
        return self.classifier(rep)

    def _group_loss(self, outputs, y, mask):
        """Return cross-entropy over the rows selected by ``mask`` (0 if none)."""
        group_outputs = outputs[mask]
        if group_outputs.shape[0] == 0:
            # A mean-reduced CrossEntropyLoss over zero samples is NaN, which
            # would turn the whole weighted loss (and every gradient) into NaN
            # whenever a batch lacks one group. An empty sum is an exact 0 that
            # stays attached to the autograd graph, so backward() still works.
            return group_outputs.sum()
        return self.criterion(group_outputs, y[mask])

    def training_step(self, batch, batch_idx):
        """Compute the group-weighted loss on neutralized representations.

        Returns:
            ``w * CE(s == 0) + (1 - w) * CE(s == 1)`` with
            ``w = regularization_weight``. A group absent from the batch
            contributes 0 instead of NaN.
        """
        x, s, y = batch
        # Re-applied every step: a .train() call on this LightningModule (e.g.
        # by the trainer after validation) recursively puts the submodule back
        # into train mode.
        self.pretrain_model.eval()
        with torch.no_grad():
            pred, rep = self.pretrain_model(x)
            pred_softmax = self.softmax(pred)

        neutra_rep = feature_neutralization(rep, pred_softmax, y, s, self.hid_size)
        outputs = self.classifier(neutra_rep)
        loss_A = self._group_loss(outputs, y, s == 0)
        loss_B = self._group_loss(outputs, y, s == 1)
        loss = (self.regularization_weight * loss_A
                + (1 - self.regularization_weight) * loss_B)
        acc = self.accuracy(outputs.argmax(dim=1), y.int())
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log plain (unweighted) cross-entropy and accuracy on raw reps."""
        x, s, y = batch
        outputs = self(x)
        preds = outputs.argmax(dim=1)
        loss = self.criterion(outputs, y)
        acc = self.accuracy(preds, y.int())
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log test loss/accuracy and feed predictions to the fairness metrics."""
        x, s, y = batch
        outputs = self(x)
        preds = outputs.argmax(dim=1)
        loss = self.criterion(outputs, y)
        acc = self.accuracy(preds, y.int())
        self.log('test_loss', loss, on_epoch=True)
        self.log('test_acc', acc, on_epoch=True)
        for metric in self.fairness_metrics.values():
            metric.update(y, preds, s)

    def on_test_epoch_end(self):
        """Log each fairness metric's value accumulated over the test epoch."""
        for name, metric in self.fairness_metrics.items():
            avg_fairness = metric.compute()
            self.log(name, avg_fairness)
            # NOTE(review): the metric state is never reset here; if these are
            # torchmetrics Metrics, a second trainer.test() run would keep
            # accumulating onto this epoch's state -- confirm in utils.

    def configure_optimizers(self):
        """Optimize only the classifier with Adam; pretrain_model stays fixed."""
        optimizer = optim.Adam(self.classifier.parameters(), lr=self.learning_rate)
        return optimizer
    