__author__ = 'Qi'
# Created on 4/11/22.
import torch
import random
import torch.nn as nn


# Representation neutralization
def feature_neutralization(r_batch, p_batch, y_batch, a_batch, HIDDEN_DIM, device="cuda"):
    """Mix each sample with a same-label, opposite-attribute partner.

    For every sample ``i`` a partner ``j`` is drawn uniformly at random (via
    Python's ``random.choice``) from the batch samples that have the same label
    ``y`` but the other attribute ``a``. If that group is empty in the batch,
    the partner is drawn from the sample's own ``(y, a)`` group instead (which
    may be the sample itself). Outputs are the interpolations
    ``w * r_batch[i] + (1 - w) * r_batch[j]`` for ``w`` in 0.5, 0.6, ..., 0.9,
    and ``0.5 * p_batch[i] + 0.5 * p_batch[j]``.

    :param r_batch: per-sample representations; rows must broadcast to
        ``(HIDDEN_DIM,)`` (presumably shape ``(B, HIDDEN_DIM)`` -- TODO confirm).
    :param p_batch: per-sample probabilities; rows must broadcast to ``(2,)``.
    :param y_batch: labels, expected to be 0 or 1 (one per sample).
    :param a_batch: attributes, expected to be 0 or 1 (one per sample); its
        first dimension defines the batch size ``B``.
    :param HIDDEN_DIM: width of the representation outputs.
    :param device: device for the five representation outputs. The default
        ``"cuda"`` matches the previously hard-coded ``.cuda()``.
    :return: ``(repre_5, repre_6, repre_7, repre_8, repre_9, probability_5)``.
        ``repre_k`` puts weight ``k/10`` on the sample itself and has shape
        ``(B, HIDDEN_DIM)`` on ``device``. ``probability_5`` has shape
        ``(B, 2)`` and is created on the default device. As before, it is NOT
        moved to ``device``.
    :raises ValueError: if any label or attribute is not 0 or 1.
    """
    n = a_batch.shape[0]

    # One host transfer for all labels, instead of several GPU->CPU syncs per sample.
    labels = y_batch.reshape(-1).tolist()
    attrs = a_batch.reshape(-1).tolist()
    keys = [(labels[i], attrs[i]) for i in range(n)]

    # Batch indices of each (label, attribute) group, in batch order.
    groups = {(0, 0): [], (0, 1): [], (1, 0): [], (1, 1): []}
    for i, key in enumerate(keys):
        if key not in groups:
            # The old code left the partner unbound (first sample) or reused
            # the previous sample's partner; fail loudly instead.
            raise ValueError(
                "feature_neutralization expects binary labels and attributes, "
                "got y=%r, a=%r at index %d" % (key[0], key[1], i)
            )
        groups[key].append(i)

    # Prefer a partner with the same label and the opposite attribute. Fall back
    # to the sample's own group, which is never empty because it contains i.
    # random.choice over index lists of the same length/order as the old pair
    # lists keeps the random stream identical to the previous implementation.
    partner_idx = torch.tensor(
        [random.choice(groups[(y, 1 - a)] or groups[(y, a)]) for y, a in keys],
        dtype=torch.long,
    )

    r_self = r_batch[:n]
    r_partner = r_batch[partner_idx.to(r_batch.device)]
    # Explicit pairs, not (w, 1 - w): e.g. 1 - 0.7 != 0.3 in floating point.
    weights = ((0.5, 0.5), (0.6, 0.4), (0.7, 0.3), (0.8, 0.2), (0.9, 0.1))
    representations = []
    for w_self, w_partner in weights:
        # Preallocate with the default dtype, then copy_ in (broadcasting and
        # cross-device copy), as the old row-by-row assignment did.
        mixed = torch.zeros(n, HIDDEN_DIM, device=device)
        mixed.copy_(w_self * r_self + w_partner * r_partner)
        representations.append(mixed)

    mixed_p = 0.5 * p_batch[:n] + 0.5 * p_batch[partner_idx.to(p_batch.device)]
    if mixed_p.dim() == 1:
        # A scalar per sample is broadcast across both columns, as before.
        mixed_p = mixed_p.unsqueeze(1)
    # NOTE(review): kept on the default (CPU) device as in the original code;
    # confirm whether callers expect it on ``device`` instead.
    probability = torch.zeros(n, 2)
    probability.copy_(mixed_p)

    return (*representations, probability)


def DRO(loss, label, tau=3):
    """Reweight losses with detached exponential-tilting weights and sum them.

    The weights are ``p = softmax(loss / tau)`` taken over *all* elements of
    ``loss`` (it is flattened first). ``p`` is detached, so gradients flow only
    through the ``loss`` factor.

    :param loss: tensor of per-sample losses (any shape).
    :param label: unused; kept so existing call sites keep working.
    :param tau: temperature; smaller values put more weight on larger losses.
    :return: 0-dim tensor ``sum(p * loss)``.
    """
    # torch.softmax subtracts the maximum internally. The previous
    # exp(loss / tau) / sum(exp(loss / tau)) overflowed to inf / inf = NaN for
    # large loss / tau.
    p = torch.softmax((loss / tau).reshape(-1), dim=0).reshape(loss.shape).detach()
    return torch.sum(p * loss)


def class_balanced_attributes(loss, a_batch):
    """Average of the per-attribute mean losses (``a == 1`` and ``a == 0``).

    Each attribute group contributes its mean loss with equal weight,
    regardless of group size. A group absent from the batch is skipped; the
    old code divided 0 by 0 and returned NaN. If neither group is present, the
    result falls back to ``loss.mean()``.

    :param loss: per-sample losses (flattened internally).
    :param a_batch: per-sample attributes, expected to be 0 or 1.
    :return: 0-dim tensor.
    """
    loss = loss.reshape(-1, 1)
    group_means = []
    for value in (1, 0):
        mask = a_batch == value
        if mask.any():
            group_means.append(loss[mask].mean())
    if not group_means:
        return loss.mean()
    return sum(group_means) / len(group_means)


def construct_neighbourhood_loss(r_batch, loss, y_batch, a_batch, neb_tau=1, neb_weight=0.3):
    """Mean loss plus a weighted cross-attribute neighbourhood loss.

    Samples are split into four groups by label ``y`` and attribute ``a``. Per
    this file's comments, ``a == 1`` is male and ``a == 0`` is female. The
    groups are pm (y=1, a=1), pf (y=1, a=0), nm (y=0, a=1) and nf (y=0, a=0).
    For each sample, the losses of the same-label, opposite-attribute group are
    averaged with weights ``softmax(cos_sim / neb_tau)`` over that group. Each
    group's total is divided by the number of samples with the query group's
    attribute (``male_cnt`` or ``female_cnt``), and the four totals are summed.

    If any of the four groups is empty, the neighbourhood term falls back to
    ``loss.mean()``.

    :param r_batch: representations of shape (B, d); rows are unit-normalised here.
    :param loss: per-sample losses, B elements.
    :param y_batch: 1-D labels (0/1).
    :param a_batch: 1-D attributes (0/1).
    :param neb_tau: softmax temperature of the similarity weights.
    :param neb_weight: weight of the neighbourhood term. The default 0.3 is the
        previously hard-coded value.
    :return: 0-dim tensor ``neb_weight * neighbourhood_loss + loss.mean()``.
    """
    loss = loss.view(-1, 1)
    mean_loss = loss.mean()

    masks = {
        'pm': (y_batch == 1) & (a_batch == 1),
        'pf': (y_batch == 1) & (a_batch == 0),
        'nm': (y_batch == 0) & (a_batch == 1),
        'nf': (y_batch == 0) & (a_batch == 0),
    }
    if any(not mask.any() for mask in masks.values()):
        return neb_weight * mean_loss + mean_loss

    # Unit-normalise rows so dot products are cosine similarities. A zero row
    # gives 0/0 = NaN, which is replaced by the constant 1/d.
    unit_r = r_batch / r_batch.norm(dim=1, keepdim=True)  # B x d
    unit_r[torch.isnan(unit_r)] = 1 / unit_r.size(1)

    rpt = {g: unit_r[m] for g, m in masks.items()}
    ls = {g: loss[m] for g, m in masks.items()}
    female_cnt = len(ls['pf']) + len(ls['nf'])
    male_cnt = len(ls['pm']) + len(ls['nm'])

    # Each group is compared with the same-label group of the other attribute.
    neb_loss = (
        _neighbourhood_sum(rpt['pm'], rpt['pf'], ls['pf'], neb_tau) / male_cnt
        + _neighbourhood_sum(rpt['pf'], rpt['pm'], ls['pm'], neb_tau) / female_cnt
        + _neighbourhood_sum(rpt['nm'], rpt['nf'], ls['nf'], neb_tau) / male_cnt
        + _neighbourhood_sum(rpt['nf'], rpt['nm'], ls['nm'], neb_tau) / female_cnt
    )
    return neb_weight * neb_loss + mean_loss


def _neighbourhood_sum(query_rpt, key_rpt, key_ls, tau):
    """Sum over queries of the softmax(similarity / tau)-weighted key losses."""
    # torch.softmax is max-stabilised. The previous exp / sum(exp) overflowed
    # (inf / inf = NaN) for small tau.
    weights = torch.softmax(query_rpt.matmul(key_rpt.T) / tau, dim=1)  # B_q x B_k
    return torch.sum(weights.matmul(key_ls))







def EOR_regularizer(logits, y_batch, a_batch):
    """Sum of squared gaps between attribute-group mean logits within each class.

    For class 0, ``FPR`` is the squared L2 distance between the mean logits of
    the ``a == 0`` and ``a == 1`` samples. ``TPR`` is the same quantity for
    class 1. Given the TPR/FPR naming, this is presumably an equalized-odds
    style penalty.

    A gap whose two groups are not both present in the batch contributes 0.
    Emptiness is now checked explicitly, so genuine NaNs in ``logits`` are no
    longer silently masked.

    :param logits: per-sample model outputs, indexed along dim 0.
    :param y_batch: 1-D labels (0/1).
    :param a_batch: 1-D attributes (0/1).
    :return: 0-dim tensor. Previously this was the Python int ``0`` when both
        gaps were undefined.
    """
    def _gap(mask_first, mask_second):
        # The mean over an empty selection would be NaN; a missing group means no gap.
        if not mask_first.any() or not mask_second.any():
            return logits.new_zeros(())
        diff = torch.mean(logits[mask_first], 0) - torch.mean(logits[mask_second], 0)
        return torch.sum(diff ** 2)

    index_a0c0 = (a_batch == 0) & (y_batch == 0)
    index_a0c1 = (a_batch == 0) & (y_batch == 1)
    index_a1c0 = (a_batch == 1) & (y_batch == 0)
    index_a1c1 = (a_batch == 1) & (y_batch == 1)

    fpr_gap = _gap(index_a0c0, index_a1c0)
    tpr_gap = _gap(index_a1c1, index_a0c1)
    return fpr_gap + tpr_gap



class RAAN(nn.Module):
    """PyTorch module for the batch robust loss estimator.

    For every sample, the losses of the same-label, opposite-attribute group in
    the batch are weighted by ``exp(cos_sim / neb_tau)``. Define
    ``gx1 = sum_j exp(s_ij / tau) * loss_j`` and ``gx2 = sum_j exp(s_ij / tau)``.
    The module keeps per-sample moving averages ``u_1`` (of gx1) and ``u_2``
    (of gx2), addressed by the dataset ``index`` passed to ``forward``. It
    returns ``mean(gx1 / u_2 - u_1 * gx2 / u_2**2)`` for each attribute group,
    summed over groups. This is the derivative form of ``gx1 / gx2`` with the
    non-differentiated factors replaced by their moving averages.
    """

    def __init__(self, gamma=0.5, data_size=1000, device='cuda'):
        """
        :param gamma: step size of the moving averages
            (``u <- (1 - gamma) * u + gamma * new``).
        :param data_size: length of the ``u_1``/``u_2`` state; ``index`` values
            passed to ``forward`` must be valid indices into it.
        :param device: device for the state tensors. The default ``'cuda'``
            matches the previously hard-coded ``.cuda()``.
        """
        super(RAAN, self).__init__()
        # Non-persistent buffers follow ``module.to(...)`` but stay out of
        # ``state_dict``, so checkpoints saved before this change still load.
        self.register_buffer('u_1', torch.zeros(data_size, device=device), persistent=False)
        self.register_buffer('u_2', torch.zeros(data_size, device=device), persistent=False)
        self.gamma = gamma
        print('RAAN criterion created.')

    def forward(self, index, r_batch, loss, y_batch, a_batch, neb_tau=1):
        """Update ``u_1``/``u_2`` for this batch and return the RAAN loss.

        :param index: dataset indices of the batch samples, used to address
            ``u_1``/``u_2``.
        :param r_batch: representations of shape (B, d).
        :param loss: per-sample losses, B elements.
        :param y_batch: 1-D labels (0/1).
        :param a_batch: 1-D attributes (0/1). Per this file's comments,
            1 = male and 0 = female.
        :param neb_tau: temperature of the similarity weights.
        :return: 0-dim tensor. If any (y, a) group is empty, returns
            ``loss.mean()`` and leaves ``u_1``/``u_2`` untouched.
        """
        loss = loss.view(-1, 1)
        masks = {
            'pm': (y_batch == 1) & (a_batch == 1),
            'pf': (y_batch == 1) & (a_batch == 0),
            'nm': (y_batch == 0) & (a_batch == 1),
            'nf': (y_batch == 0) & (a_batch == 0),
        }
        if any(not mask.any() for mask in masks.values()):
            return loss.mean()

        # Unit-normalise rows so dot products are cosine similarities. A zero
        # row gives 0/0 = NaN, which is replaced by the constant 1/d.
        unit_r = r_batch / r_batch.norm(dim=1, keepdim=True)  # B x d
        unit_r[torch.isnan(unit_r)] = 1 / unit_r.size(1)

        ls = {g: loss[m] for g, m in masks.items()}
        rpt = {g: unit_r[m] for g, m in masks.items()}
        idx = {g: index[m] for g, m in masks.items()}
        # Each group is compared with the same-label group of the other attribute.
        partner = {'pm': 'pf', 'pf': 'pm', 'nm': 'nf', 'nf': 'nm'}

        gx1, gx2 = {}, {}
        for g, p in partner.items():
            # Unnormalised weights; compute exp once and reuse it for both sums.
            exp_sim = torch.exp(rpt[g].matmul(rpt[p].T) / neb_tau)  # B_g x B_p
            gx1[g] = exp_sim.matmul(ls[p])                # B_g x 1
            gx2[g] = torch.sum(exp_sim, 1, keepdim=True)  # B_g x 1

        def gather(groups):
            # Concatenate gx1, gx2 and dataset indices for the given groups, in order.
            return (torch.cat([gx1[g] for g in groups]).view(-1),
                    torch.cat([gx2[g] for g in groups]).view(-1),
                    torch.cat([idx[g] for g in groups]))

        all_gx1, all_gx2, all_index = gather(('pm', 'pf', 'nm', 'nf'))
        # Moving-average update; new values are detached so u_1/u_2 carry no graph.
        self.u_1[all_index] = (1 - self.gamma) * self.u_1[all_index] + self.gamma * all_gx1.detach()
        self.u_2[all_index] = (1 - self.gamma) * self.u_2[all_index] + self.gamma * all_gx2.detach()

        def group_term(groups):
            # Mean over the groups' samples of gx1/u_2 - u_1*gx2/u_2^2 (post-update state).
            g1, g2, g_index = gather(groups)
            u_1, u_2 = self.u_1[g_index], self.u_2[g_index]
            return torch.mean(g1 / u_2 - (u_1 * g2) / (u_2 ** 2))

        female_term = group_term(('pf', 'nf'))
        male_term = group_term(('pm', 'nm'))
        return female_term + male_term
