import numpy as np
import torch 
from torchmetrics import Metric
import random
import torch

def add_label_bias(yclean, rho, theta_dict, seed=1359):
    """
    Return a copy of ``yclean`` with group-dependent random label flips.

    Labels are encoded as 0/1 in the code; in the notation below the
    negative class "-1" corresponds to label 0 and "+1" to label 1.
    Z is the clean label, Y the biased label and A the sensitive attribute:

    theta_0_p: P(Y=+1|Z=-1,A=0)
    theta_0_m: P(Y=-1|Z=+1,A=0)
    theta_1_p: P(Y=+1|Z=-1,A=1)
    theta_1_m: P(Y=-1|Z=+1,A=1)

    Args:
        yclean: tensor of clean labels (values 0 or 1). Not modified.
        rho: tensor of sensitive attributes (values 0 or 1); assumed to
            have the same shape as ``yclean``.
        theta_dict: mapping holding the four flip probabilities under the
            keys ``theta_0_p``, ``theta_0_m``, ``theta_1_p``, ``theta_1_m``.
        seed: value passed to ``torch.manual_seed``. Note that this reseeds
            the global torch RNG.

    Returns:
        A new tensor shaped like ``yclean`` in which each sample has been
        flipped independently with the probability of its (A, Z) group.
        Samples whose attribute or label is neither 0 nor 1 are left as-is.
    """
    torch.manual_seed(seed)

    # (sensitive attribute, clean label, flip probability, label after flip)
    flip_rules = (
        (0, 1, theta_dict['theta_0_m'], 0),
        (0, 0, theta_dict['theta_0_p'], 1),
        (1, 1, theta_dict['theta_1_m'], 0),
        (1, 0, theta_dict['theta_1_p'], 1),
    )

    biased = yclean.clone()

    # The four groups are disjoint, so a single uniform draw per sample is
    # enough to decide every flip independently.
    draws = torch.rand(yclean.shape, device=yclean.device)

    for attribute, clean_label, flip_prob, flipped_label in flip_rules:
        # Group membership is always evaluated on the clean labels so an
        # earlier flip can never cascade into a later rule.
        flip = (rho == attribute) & (yclean == clean_label) & (draws < flip_prob)
        biased[flip] = flipped_label

    return biased




class FairnessMetricDDP(Metric):
    """Absolute gap in prediction accuracy between the groups A=0 and A=1.

    Four running counters are kept as metric states, each registered with a
    "sum" reduction: the number of correct predictions and the number of
    samples for each of the two sensitive-attribute groups.
    """

    def __init__(self):
        super().__init__()

        # Registration order: correct/total for A=0, then correct/total for A=1.
        for state_name in ("correct_A0", "total_A0", "correct_A1", "total_A1"):
            self.add_state(state_name, default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, labels, preds, sensitive_att):
        """Accumulate per-group hit and sample counts for one batch.

        NOTE(review): assumes the three arguments are tensors of matching
        shape -- confirm against the caller.
        """
        hits = (preds == labels).float()
        in_group0 = (sensitive_att == 0)
        in_group1 = (sensitive_att == 1)

        self.correct_A0 += hits[in_group0].sum()
        self.total_A0 += in_group0.sum()

        self.correct_A1 += hits[in_group1].sum()
        self.total_A1 += in_group1.sum()

    def compute(self):
        """Return |accuracy(A=0) - accuracy(A=1)|.

        A group for which no sample has been seen contributes an accuracy
        of 0.0.
        """
        if self.total_A0 > 0:
            accuracy_group0 = self.correct_A0 / self.total_A0
        else:
            accuracy_group0 = torch.tensor(0.0)

        if self.total_A1 > 0:
            accuracy_group1 = self.correct_A1 / self.total_A1
        else:
            accuracy_group1 = torch.tensor(0.0)

        return torch.abs(accuracy_group0 - accuracy_group1)
    


class FairnessMetricDP(Metric):
    """Absolute gap in positive-prediction rate between groups A=0 and A=1.

    For each sensitive-attribute group the metric tracks how many
    predictions equal 1 and how many samples were seen; all four counters
    are registered with a "sum" reduction.
    """

    def __init__(self):
        super().__init__()

        # Registration order: positives/total for A=0, then for A=1.
        for state_name in ("pred_A0", "total_A0", "pred_A1", "total_A1"):
            self.add_state(state_name, default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, labels, preds, sensitive_att):
        """Accumulate per-group positive-prediction counts for one batch.

        ``labels`` is accepted but not used by this metric.
        """
        in_group0 = (sensitive_att == 0)
        in_group1 = (sensitive_att == 1)

        positives_group0 = (preds[in_group0] == 1).float()
        self.pred_A0 += positives_group0.sum()
        self.total_A0 += in_group0.sum()

        positives_group1 = (preds[in_group1] == 1).float()
        self.pred_A1 += positives_group1.sum()
        self.total_A1 += in_group1.sum()

    def compute(self):
        """Return |P(pred=1 | A=1) - P(pred=1 | A=0)|.

        A group for which no sample has been seen contributes a rate of 0.0.
        """
        if self.total_A0 > 0:
            rate_group0 = self.pred_A0 / self.total_A0
        else:
            rate_group0 = torch.tensor(0.0)

        if self.total_A1 > 0:
            rate_group1 = self.pred_A1 / self.total_A1
        else:
            rate_group1 = torch.tensor(0.0)

        return torch.abs(rate_group1 - rate_group0)
    



class FairnessMetricEO(Metric):
    """Absolute gap in true-positive rate between groups A=0 and A=1.

    For each sensitive-attribute group the metric tracks the number of
    samples with true label 1 and how many of those were predicted as 1;
    all four counters are registered with a "sum" reduction.
    """

    def __init__(self):
        super().__init__()

        # Registration order: true positives / positives for A=0, then for A=1.
        for state_name in ("tp_A0", "p_A0", "tp_A1", "p_A1"):
            self.add_state(state_name, default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, y_true, y_pred, sensitive_att):
        """Accumulate per-group positive and true-positive counts for one batch.

        NOTE(review): assumes the three arguments are tensors of matching
        shape -- confirm against the caller.
        """
        is_positive = (y_true == 1)

        positives_group0 = (sensitive_att == 0) & is_positive
        positives_group1 = (sensitive_att == 1) & is_positive

        self.tp_A0 += ((y_pred == 1) & positives_group0).float().sum()
        self.p_A0 += positives_group0.float().sum()

        self.tp_A1 += ((y_pred == 1) & positives_group1).float().sum()
        self.p_A1 += positives_group1.float().sum()

    def compute(self):
        """Return |TPR(A=1) - TPR(A=0)|.

        A group with no positive sample seen so far contributes a
        true-positive rate of 0.0.
        """
        if self.p_A0 > 0:
            tpr_group0 = self.tp_A0 / self.p_A0
        else:
            tpr_group0 = torch.tensor(0.0)

        if self.p_A1 > 0:
            tpr_group1 = self.tp_A1 / self.p_A1
        else:
            tpr_group1 = torch.tensor(0.0)

        return torch.abs(tpr_group1 - tpr_group0)
    
 


