import torch
import statistics
from tool.logger import *


def Experiment_Model_testing(device, testing_dataloader, mask_s1_flag, testing_model, hypothesis, only_acc=False):
    """Evaluate a model's accuracy and group-fairness metrics on a test set.

    Args:
        device: Device that inputs, labels and group attributes are moved to.
        testing_dataloader: Iterable of batches. Each batch is a mapping with
            tensors under ``"X"`` and ``"y"``, plus ``"s1"`` / ``"s2"`` when
            fairness metrics are requested. Labels and group attributes may
            be shaped ``(N,)`` or ``(N, 1)``; they are flattened before use.
        mask_s1_flag: When truthy the group attribute is read from
            ``batch["s2"]``, otherwise from ``batch["s1"]``.
        testing_model: Module called as ``testing_model(X)``. It is switched
            to eval mode and is not switched back.
        hypothesis: ``"LR"`` thresholds the model output at 0.5; any other
            value takes the arg-max over dimension 1.
        only_acc: When true, skip the fairness metrics and return only the
            accuracy.

    Returns:
        ``acc`` when ``only_acc`` is true, otherwise the tuple
        ``(acc, DEO, EOD, SPD, FR, HM)``. All values are 0-dim tensors except
        ``HM``, which is a Python float. Only samples whose label / group
        value is exactly 0 or 1 are counted. If a group contains no positive
        samples the affected ratios evaluate to NaN (0/0 tensor division)
        rather than raising.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    acc_numerator = 0
    acc_denominator = 0

    # Counts of predicted positives per group.
    num_s1_pred1 = 0
    num_s0_pred1 = 0

    # Counts of true positives per group.
    num_s1_pred1_y1 = 0
    num_s0_pred1_y1 = 0

    # Label counts per group.
    num_s1_y1 = 0
    num_s1_y0 = 0
    num_s0_y1 = 0
    num_s0_y0 = 0

    # Model testing
    testing_model.eval()
    with torch.no_grad():
        for batch in testing_dataloader:
            X = batch["X"].to(device)
            # Flatten so that a column-shaped (N, 1) label tensor cannot
            # broadcast against the (N,) predictions into an (N, N) matrix.
            y = batch["y"].to(device).reshape(-1)
            output = testing_model(X).to(device)
            if hypothesis == "LR":
                prediction = (output >= 0.5).reshape(-1)
            else:
                # softmax is monotonic, so the arg-max of the raw scores is
                # the predicted class.
                prediction = output.argmax(dim=1)
            acc_numerator += prediction.eq(y).sum()
            acc_denominator += X.shape[0]

            if only_acc:
                continue

            s = batch["s2"] if mask_s1_flag else batch["s1"]
            # Same flattening as for y: keep every mask one-dimensional.
            s = s.to(device).reshape(-1)

            y_is_0 = y == 0
            y_is_1 = y == 1
            s_is_0 = s == 0
            s_is_1 = s == 1
            pred_is_1 = prediction == 1

            num_s1_pred1 += (s_is_1 & pred_is_1).sum()
            num_s0_pred1 += (s_is_0 & pred_is_1).sum()

            num_s1_pred1_y1 += (s_is_1 & pred_is_1 & y_is_1).sum()
            num_s0_pred1_y1 += (s_is_0 & pred_is_1 & y_is_1).sum()

            num_s1_y1 += (s_is_1 & y_is_1).sum()
            num_s1_y0 += (s_is_1 & y_is_0).sum()
            num_s0_y1 += (s_is_0 & y_is_1).sum()
            num_s0_y0 += (s_is_0 & y_is_0).sum()

    if acc_denominator == 0:
        # Same exception type the bare division used to raise, but with a
        # message that points at the actual cause.
        raise ZeroDivisionError("testing_dataloader yielded no samples; accuracy is undefined")

    acc = acc_numerator / acc_denominator

    if only_acc:
        return acc

    # a = P(y_hat = 1 | s = 0, y = 1), b = P(y_hat = 1 | s = 1, y = 1)
    a = num_s0_pred1_y1 / num_s0_y1
    b = num_s1_pred1_y1 / num_s1_y1

    # Difference of Equality of Opportunity violation (definition from Renyi).
    DEO = abs(a - b)

    # Equal Opportunity Difference (definition from FairFed): signed gap.
    EOD = a - b

    # Statistical Parity Difference (definition from FairFed): gap between
    # the positive-prediction rates of the two groups.
    SPD = (num_s0_pred1 / (num_s0_y1 + num_s0_y0)) - (num_s1_pred1 / (num_s1_y1 + num_s1_y0))

    # Fairness measurement (definition from FedFair).
    FR = 1 - DEO

    # Harmonic mean of fairness and accuracy.
    HM = statistics.harmonic_mean([float(acc), float(FR)])

    return acc, DEO, EOD, SPD, FR, HM

