import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
from fedlearn.models.models import choose_model
from torch.utils.data import DataLoader
from fedlearn.models.FairBatchSampler import FairBatch, CustomDataset
from fedlearn.utils.model_utils import weighted_loss
from fedlearn.utils.model_utils import get_sort_idxs, get_cdf, get_sample_target
import torch.nn.functional as F
import copy
import cvxpy as cp
from fedlearn.models.client import Client

class CGFFedClient(Client):
    """Federated client for calibrated group-fair (CGF) training.

    The client holds two models:

    * ``self.model`` -- the model whose flat parameters are read/written via
      ``get_model_params`` / ``set_params`` and returned from training.
    * ``self.personal_model`` -- a deep copy of the constructor's ``model``
      that is trained locally and is not part of the returned parameters.

    Evaluation with ``CFG=True`` mixes the softmax outputs of the two models
    with weight ``self.ensemble_weight_global``. Fairness is encoded through
    dual variables (``global_lamb*`` and ``local_mu*``) that ``get_cal_Matrix``
    turns into one calibration matrix per sensitive-attribute value; these
    matrices re-weight the per-class negative log-likelihood during training.
    """

    def __init__(self, cid, train_data, val_data, test_data, options=None, model=None):
        """Set up dual variables, data statistics, the personal model and its optimizer.

        Args:
            cid: Client id; also indexes the per-client entries of
                ``self.data_info`` (``train_client_A_info``,
                ``train_client_A_Y_info``).
            train_data, val_data, test_data: Datasets forwarded to ``Client``.
            options: Configuration dict. Keys read here: ``fairness_constraints``
                (with ``fairness_measure``, ``global`` and ``local``),
                ``fairness_type`` (``'groupwise'`` or ``'subgroup'``),
                ``ensemble_params_lr``, ``dual_params_lr``,
                ``dual_params_bound`` and ``local_optimizer``.
                ``None`` (the default) is treated as an empty dict.
            model: Model forwarded to ``Client``; a deep copy becomes the
                personal model.
        """
        # ``None`` instead of a ``{}`` default so separate calls never share one dict.
        options = {} if options is None else options
        super().__init__(cid, train_data, val_data, test_data, options, model)
        self.fairness_constraints = options['fairness_constraints']
        self.fair_metric = self.fairness_constraints['fairness_measure']
        self.global_fair_constraint = self.fairness_constraints['global']
        self.local_fair_constraint = self.fairness_constraints['local']

        self.Alabel = [int(a) for a in self.data_info['Alabel']]
        self.Ylabel = [int(y) for y in self.data_info['Ylabel']]

        # Dual variables, all initialised to 0.1. Index/suffix 0 (or _1) pairs
        # with the "+measure <= constraint" side and 1 (or _2) with the
        # "-measure <= constraint" side (see _update_local_mu).
        if self.options['fairness_type'] == 'groupwise':
            if self.fair_metric == 'DP':
                self.global_lamb = torch.tensor([0.1, 0.1], requires_grad=False)
                self.local_mu = torch.tensor([0.1, 0.1], requires_grad=False)
            elif self.fair_metric == 'EO':
                # First row for Y=1, second row for Y=0.
                self.global_lamb = torch.tensor([[0.1, 0.1],
                                                 [0.1, 0.1]], requires_grad=False)
                self.local_mu = torch.tensor([[0.1, 0.1],
                                              [0.1, 0.1]], requires_grad=False)
        elif self.options['fairness_type'] == 'subgroup':
            if self.fair_metric in ('DP', 'EOP'):
                # One multiplier per (sensitive value, label) pair.
                shape = (len(self.Alabel), len(self.Ylabel))
                self.global_lamb_1 = torch.ones(shape) * 0.1
                self.global_lamb_2 = torch.ones(shape) * 0.1
                self.local_mu_1 = torch.ones(shape) * 0.1
                self.local_mu_2 = torch.ones(shape) * 0.1

        self.ensemble_weight_global = 0.5
        self.train_A_info = self.data_info['train_A_info']

        self.class_num = len(self.Ylabel)
        self.train_client_A_info = self.data_info['train_client_A_info'][self.cid]

        # Empirical frequencies; every "p_*" below divides a count by data_info['train_num'],
        # except p_k_given_a which divides this client's A=a count by the total A=a count.
        n_train = self.data_info['train_num']
        self.p_a_k = {a: self.train_client_A_info[a] / n_train for a in self.Alabel}
        self.p_k_given_a = {a: self.train_client_A_info[a] / self.train_A_info[a] for a in self.Alabel}

        self.train_A_Y_info = self.data_info['train_A_Y_info']
        self.train_client_A_Y_info = self.data_info['train_client_A_Y_info'][self.cid]
        self.p_a_y_k = {a: {y: self.train_client_A_Y_info[a][y] / n_train for y in self.Ylabel} for a in self.Alabel}
        self.p_a_y = {a: {y: self.train_A_Y_info[a][y] / n_train for y in self.Ylabel} for a in self.Alabel}
        self.p_y_k = {y: sum([self.train_client_A_Y_info[a][y] for a in self.Alabel]) / n_train for y in self.Ylabel}
        self.p_y = {y: sum([self.train_A_Y_info[a][y] for a in self.Alabel]) / n_train for y in self.Ylabel}

        self.ensemble_params_lr = self.options['ensemble_params_lr']
        self.dual_params_lr = self.options['dual_params_lr']
        self.dual_params_bound = self.options['dual_params_bound']

        # CGFFed personal model and its (initially zero) calibration matrices.
        self.personal_model = copy.deepcopy(model)
        self.M_a_k = {a: torch.zeros(self.class_num, self.class_num) for a in self.Alabel}

        # Personal optimizer of the configured kind. ``optim.Adagrad`` is the
        # optimizer class (``optim.adagrad`` is its module and is not callable).
        # An unknown name leaves ``personal_optimizer`` unset.
        optimizer_classes = {'sgd': optim.SGD, 'adam': optim.Adam, 'adagrad': optim.Adagrad}
        optimizer_cls = optimizer_classes.get(options['local_optimizer'].lower())
        if optimizer_cls is not None:
            self.personal_optimizer = optimizer_cls(self.personal_model.parameters(), lr=self.local_lr, weight_decay=self.wd)

    @staticmethod
    def move_model_to_gpu(model, options):
        """Move ``model`` to the configured device when ``options['gpu']`` is True.

        ``options['device']`` (optional) may be a ``torch.device``, a device
        string such as ``'cuda:1'`` or a device index; it is normalised with
        ``torch.device`` so that ``.index`` is always available for the log line.
        """
        if 'gpu' in options and (options['gpu'] is True):
            if 'device' in options:
                device = torch.device(options['device'])
            else:
                device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            model.to(device)
            print('>>> Use gpu on self.device {}'.format(device.index))
        else:
            print('>>> Do not use gpu')

    def set_params(self, flat_params):
        '''set model parameters, where input is a flat parameter'''
        self.model.set_params(flat_params)

    def get_model_params(self):
        '''get local flat model parameters, transform torch model parameters into flat tensor'''
        return self.model.get_flat_params()

    def get_global_params(self, global_params):
        """Store a deep copy of ``global_params`` on the client."""
        self.global_params = copy.deepcopy(global_params)

    def get_grads(self, mini_batch_data):
        """Return ``(flat_grads, loss)`` of ``self.criterion`` on one mini-batch.

        Gradients are clipped to total norm 50 before being flattened; both
        returned tensors are detached copies on the CPU.
        """
        x, y = mini_batch_data
        self.model.train()
        if self.gpu:
            x, y = x.to(self.device), y.to(self.device)
        self.optimizer.zero_grad()
        pred = self.model(x)
        loss = self.criterion(pred, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 50)
        flat_grads = self.model.get_flat_grads().cpu().detach()
        return torch.empty_like(flat_grads).copy_(flat_grads), loss.cpu().detach()

    def get_pred(self):
        """Predict on the training set with ``self.model``.

        Returns:
            ``(scores, labels, A)``: the raw model outputs and the thresholded
            (binary) / argmax (multiclass) predictions, each flattened into a
            column tensor, plus ``self.train_data.A``.
        """
        self.model.eval()
        dataloader = DataLoader(self.train_data, batch_size=self.batch_size, shuffle=False)
        pred_score = predicted = torch.tensor([])
        with torch.no_grad():
            for x, y, a in dataloader:
                if self.gpu:
                    x, y = x.to(self.device), y.to(self.device)
                self.optimizer.zero_grad()
                pred_score_batch = self.model(x).detach().cpu()
                if self.mission == 'binary':
                    predicted_batch = ((torch.sign(pred_score_batch - 0.5) + 1) / 2)
                elif self.mission == 'multiclass':
                    _, predicted_batch = torch.max(pred_score_batch, 1)
                # reshape(-1) instead of squeeze(): squeeze turns a 1-sample batch
                # into a 0-d tensor, which torch.cat cannot concatenate.
                pred_score = torch.cat([pred_score, pred_score_batch.reshape(-1)])
                predicted = torch.cat([predicted, predicted_batch.reshape(-1)])
        return pred_score.reshape(-1, 1).clone(), predicted.reshape(-1, 1).clone(), self.train_data.A

    def local_train(self):
        """Train ``self.model`` with ``self.criterion`` for ``num_local_round`` batches.

        Returns:
            ``((len(train_data), flat_params), stats)`` where ``stats`` holds the
            client id, wall time, average train loss/accuracy and the norm, max
            and min of the flat parameters.
        """
        begin_time = time.time()

        for _ in range(self.num_local_round):
            self.model.train()
            (x, y, a) = self.get_next_train_batch()
            if self.gpu:
                x, y = torch.squeeze(x.to(self.device)), y.to(self.device).reshape(-1, 1)
            self.optimizer.zero_grad()
            pred = self.model(x)
            loss = self.criterion(pred, y.view(-1).long())
            loss.backward()
            self.optimizer.step()

        self.local_model = self.get_model_params()

        train_stats = self.model_eval(self.train_dataloader)
        param_dict = {'norm': torch.norm(self.local_model).item(),
                      'max': self.local_model.max().item(),
                      'min': self.local_model.min().item()}

        return_dict = {'loss': train_stats['loss'] / train_stats['num'],
                       'acc': train_stats['acc'] / train_stats['num']}
        return_dict.update(param_dict)

        end_time = time.time()
        stats = {'id': self.cid, 'time': round(end_time - begin_time, 2)}
        stats.update(return_dict)
        return (len(self.train_data), self.local_model), stats

    @staticmethod
    def _shift_to_positive(M_a_k, margin=0.001):
        """Add one common constant to every matrix so the smallest entry is ``margin``.

        Nothing is changed when no entry is negative. Using the same shift for
        every group keeps the differences between the groups' matrices intact.
        """
        global_min = min(M.min().item() for M in M_a_k.values())
        if global_min < 0:
            shift = -global_min + margin
            M_a_k = {key: M + shift for key, M in M_a_k.items()}
        return M_a_k

    def _zero_direction_matrices(self):
        """Return a nested dict ``D[a_prime][y][a]`` of zero ``(class_num, class_num)`` matrices."""
        return {a_prime: {y: {aa: torch.zeros(self.class_num, self.class_num) for aa in self.Alabel}
                          for y in self.Ylabel}
                for a_prime in self.Alabel}

    def _combine_subgroup_matrices(self, lamb, mu, D_g_k, D_l_k):
        """Return ``M[a] = I - (1/p_a_k[a]) * (sum_g + sum_l)`` for the subgroup case.

        ``sum_g`` adds ``(lamb[0][a', y] - lamb[1][a', y]) * D_g_k[a'][y][a]``
        over all ``(a', y)`` pairs; ``sum_l`` does the same with ``mu`` and
        ``D_l_k``.
        """
        eye = torch.eye(self.class_num, self.class_num)
        pairs = [(a_prime, y) for a_prime in self.Alabel for y in self.Ylabel]
        M_a_k = {}
        for a in self.Alabel:
            inv_p = 1 / float(self.p_a_k[a])
            global_term = sum((lamb[0][a_prime, y] - lamb[1][a_prime, y]) * D_g_k[a_prime][y][a] for a_prime, y in pairs)
            local_term = sum((mu[0][a_prime, y] - mu[1][a_prime, y]) * D_l_k[a_prime][y][a] for a_prime, y in pairs)
            M_a_k[a] = eye - inv_p * global_term - inv_p * local_term
        return M_a_k

    def get_cal_Matrix(self, lamb, mu):
        """Build the per-group calibration matrices from the dual variables.

        Args:
            lamb: Global multipliers. ``groupwise``: shape ``(2,)`` for DP or
                ``(2, 2)`` for EO. ``subgroup``: a pair ``[lamb_1, lamb_2]`` of
                ``(len(Alabel), len(Ylabel))`` tensors.
            mu: Local multipliers, laid out like ``lamb``.

        Returns:
            Dict mapping each sensitive value ``a`` to a ``(class_num, class_num)``
            tensor, shifted by a common constant when needed so every entry is
            positive; ``None`` for an unhandled fairness type/metric.

        NOTE(review): the groupwise formulas use the sign ``2*a - 1`` and
        literal 2x2 matrices, i.e. they assume sensitive labels 0/1 and two
        classes.
        """
        fairness_type = self.options['fairness_type']
        eye = torch.eye(self.class_num, self.class_num)
        if fairness_type == 'groupwise':
            if self.fair_metric == 'DP':
                assert len(lamb) == 2
                assert len(mu) == 2
                M_a_k = {}
                for a in self.Alabel:
                    sign = 2 * a - 1
                    D_g = sign * torch.tensor([[0, self.p_k_given_a[a]],
                                               [0, self.p_k_given_a[a]]])
                    D_l = sign * torch.tensor([[0, 1],
                                               [0, 1]])
                    inv_p = 1 / float(self.p_a_k[a])
                    M_a_k[a] = eye - inv_p * (lamb[0] - lamb[1]) * D_g - inv_p * (mu[0] - mu[1]) * D_l
                return self._shift_to_positive(M_a_k)
            if self.fair_metric == 'EO':
                assert lamb.shape == (2, 2)
                assert mu.shape == (2, 2)
                M_a_k = {}
                for a in self.Alabel:
                    sign = 2 * a - 1
                    # Row 0 of lamb/mu (Y=1) acts on entry [1, 1]; row 1 (Y=0) on entry [0, 1].
                    D_g_1 = sign * torch.tensor([[0, 0],
                                                 [0, self.p_a_k[a] / self.p_a_y[a][1]]])
                    D_g_2 = sign * torch.tensor([[0, self.p_a_k[a] / self.p_a_y[a][0]],
                                                 [0, 0]])
                    D_l_1 = sign * torch.tensor([[0, 0],
                                                 [0, self.p_a_k[a] / self.p_a_y_k[a][1]]])
                    D_l_2 = sign * torch.tensor([[0, self.p_a_k[a] / self.p_a_y_k[a][0]],
                                                 [0, 0]])
                    inv_p = 1 / float(self.p_a_k[a])
                    M_a_k[a] = (eye
                                - inv_p * (lamb[0][0] - lamb[0][1]) * D_g_1
                                - inv_p * (lamb[1][0] - lamb[1][1]) * D_g_2
                                - inv_p * (mu[0][0] - mu[0][1]) * D_l_1
                                - inv_p * (mu[1][0] - mu[1][1]) * D_l_2)
                return self._shift_to_positive(M_a_k)
        elif fairness_type == 'subgroup':
            # lamb = [lamb_1, lamb_2], mu = [mu_1, mu_2]
            assert lamb[0].shape == (len(self.Alabel), len(self.Ylabel))
            assert mu[0].shape == (len(self.Alabel), len(self.Ylabel))
            D_g_k = self._zero_direction_matrices()
            D_l_k = self._zero_direction_matrices()
            if self.fair_metric == 'DP':
                # DP constraints act on the whole predicted-label column y.
                for a_prime in self.Alabel:
                    for y in self.Ylabel:
                        for aa in self.Alabel:
                            if aa == a_prime:
                                D_g_k[a_prime][y][aa][:, y] = self.p_k_given_a[aa] - self.p_a_k[aa]
                                D_l_k[a_prime][y][aa][:, y] = 1 - self.p_k_given_a[aa]
                            else:
                                D_g_k[a_prime][y][aa][:, y] = - self.p_a_k[aa]
                                D_l_k[a_prime][y][aa][:, y] = - self.p_k_given_a[aa]
            elif self.fair_metric == 'EOP':
                # EOP constraints act only on the diagonal entry (y, y).
                for a_prime in self.Alabel:
                    for y in self.Ylabel:
                        for aa in self.Alabel:
                            if aa == a_prime:
                                D_g_k[a_prime][y][aa][y, y] = self.p_a_k[aa] / self.p_a_y[aa][y] - self.p_a_k[aa] / self.p_y[y]
                                D_l_k[a_prime][y][aa][y, y] = self.p_a_k[aa] / self.p_a_y_k[aa][y] - self.p_a_k[aa] / self.p_y_k[y]
                            else:
                                D_g_k[a_prime][y][aa][y, y] = - self.p_a_k[aa] / self.p_y[y]
                                D_l_k[a_prime][y][aa][y, y] = - self.p_a_k[aa] / self.p_y_k[y]
            else:
                return None
            return self._shift_to_positive(self._combine_subgroup_matrices(lamb, mu, D_g_k, D_l_k))
        return None

    @staticmethod
    def _calibrated_sample_loss(neg_log_probs, y, a, M_tensor):
        """Per-sample calibrated loss ``sum_j M[a_i][y_i, j] * neg_log_probs[i, j]``.

        ``M_tensor`` is the stacked ``(|A|, m, m)`` calibration tensor; it is
        indexed directly by the sensitive value, so this presumably assumes
        ``Alabel`` is ``0..|A|-1`` in order -- TODO confirm.
        """
        n, _ = neg_log_probs.shape
        M_sel = M_tensor[a.view(-1).long()]  # (n, m, m)
        rows = M_sel[torch.arange(n), y.view(-1).long()]  # (n, m)
        return torch.sum(neg_log_probs * rows, dim=1)  # (n,)

    def get_calibrated_loss(self, model, M):
        """Evaluate the calibrated loss of ``model`` over the full training set.

        The per-sample loss is ``sum_j M[a][y, j] * (-log softmax(model(x))_j)``,
        the same objective ``CGF_train`` minimises.

        Args:
            model: Model to evaluate (switched to eval mode and left there).
            M: Dict of calibration matrices keyed by sensitive value, as
                returned by ``get_cal_Matrix``.

        Returns:
            Dict with ``average_loss``, ``total_loss`` and ``sample_num``.
        """
        dataloader = DataLoader(self.train_data, batch_size=self.batch_size, shuffle=False)
        M_tensor = torch.stack([M[aa] for aa in self.Alabel], dim=0).to(self.device)  # (a, m, m)
        total_loss = 0
        sample_num = 0
        model.eval()
        with torch.no_grad():
            for x, y, a in dataloader:
                if self.gpu:
                    x, y, a = x.to(self.device), y.to(self.device), a.to(self.device)
                # Computed for every mission, exactly as in CGF_train (``outputs``
                # used to be undefined when mission == 'binary').
                outputs = - F.log_softmax(model(x), dim=1)
                batch_loss = self._calibrated_sample_loss(outputs, y, a, M_tensor)
                total_loss += batch_loss.sum().detach().cpu().item()
                sample_num += outputs.shape[0]
        return {'average_loss': total_loss / sample_num, 'total_loss': total_loss, 'sample_num': sample_num}

    def _update_local_mu(self, eval_stats):
        """Projected dual-ascent step on the local multipliers.

        Each multiplier moves by ``dual_params_lr`` times its constraint value
        (``measure - local_constraint`` for the first side,
        ``-measure - local_constraint`` for the second) and is then clamped to
        ``[0, dual_params_bound]``. ``eval_stats`` is a ``model_eval`` result
        containing ``local_fair_org``.
        """
        lr = self.dual_params_lr
        c = self.local_fair_constraint
        bound = self.dual_params_bound
        fairness_type = self.options['fairness_type']
        if fairness_type == 'groupwise':
            if self.fair_metric == 'DP':
                measure = eval_stats['local_fair_org']['DP']
                self.local_mu[0] += lr * (measure - c)
                self.local_mu[1] += lr * (- measure - c)
                self.local_mu = torch.clamp(self.local_mu, min=0.0, max=bound)
            elif self.fair_metric == 'EO':
                measure = eval_stats['local_fair_org']['EO']  # [local_EO_Y1, local_EO_Y0]
                for row in (0, 1):
                    self.local_mu[row][0] += lr * (measure[row] - c)
                    self.local_mu[row][1] += lr * (- measure[row] - c)
                self.local_mu = torch.clamp(self.local_mu, min=0.0, max=bound)
        elif fairness_type == 'subgroup' and self.fair_metric in ('DP', 'EOP'):
            measure = eval_stats['local_fair_org'][self.fair_metric]
            for a_prime in self.Alabel:
                for y in self.Ylabel:
                    self.local_mu_1[a_prime][y] += lr * (measure[a_prime][y] - c)
                    self.local_mu_2[a_prime][y] += lr * (- measure[a_prime][y] - c)
            self.local_mu_1 = torch.clamp(self.local_mu_1, min=0.0, max=bound)
            self.local_mu_2 = torch.clamp(self.local_mu_2, min=0.0, max=bound)

    def _calibrated_steps(self, train_model, optimizer, frozen_model, M_tensor):
        """Take ``num_local_round`` optimizer steps on ``train_model`` with the calibrated loss.

        ``frozen_model`` is only switched to eval mode. The batch loss is
        summed, not averaged, over samples.
        """
        frozen_model.eval()
        for _ in range(self.num_local_round):
            train_model.train()
            (x, y, a) = self.get_next_train_batch()
            if self.gpu:
                x, y, a = torch.squeeze(x.to(self.device)), y.to(self.device), a.to(self.device)
            optimizer.zero_grad()
            outputs = - F.log_softmax(train_model(x), dim=1)
            loss = torch.sum(self._calibrated_sample_loss(outputs, y, a, M_tensor))
            loss.backward()
            optimizer.step()

    def CGF_train(self):
        """Run one CGF local update and return the parameters of ``self.model``.

        Steps:
            1. Rebuild ``self.M_a_k`` from the current dual variables.
            2. Multiplicatively re-weight ``ensemble_weight_global`` using the
               calibrated losses of the global and personal models.
            3. Evaluate the ensemble on the training data and take a
               dual-ascent step on the local multipliers.
            4. Train the personal model, then the global model, for
               ``num_local_round`` batches each on the calibrated loss.

        The global multipliers (``global_lamb*``) are only read here; they
        are expected to have been set beforehand.

        Returns:
            ``((len(train_data), flat_params), stats)``; ``stats`` is based on
            the ensemble evaluation from step 3, taken before training.
        """
        begin_time = time.time()

        # 1) calibration matrices
        if self.options['fairness_type'] == 'groupwise':
            self.M_a_k = self.get_cal_Matrix(self.global_lamb, self.local_mu)
        elif self.options['fairness_type'] == 'subgroup':
            self.M_a_k = self.get_cal_Matrix([self.global_lamb_1, self.global_lamb_2], [self.local_mu_1, self.local_mu_2])

        # 2) ensemble weight (exponential-weights update)
        global_loss = self.get_calibrated_loss(self.model, self.M_a_k)
        local_loss = self.get_calibrated_loss(self.personal_model, self.M_a_k)
        w1 = self.ensemble_weight_global * torch.exp(-torch.tensor(self.ensemble_params_lr * global_loss['average_loss']))
        w2 = (1 - self.ensemble_weight_global) * torch.exp(-torch.tensor(self.ensemble_params_lr * local_loss['average_loss']))
        assert w1 >= 0
        assert w2 >= 0
        w_total = w1 + w2
        if w_total > 0:
            self.ensemble_weight_global = w1 / w_total
        # Otherwise both weights underflowed to 0; keep the previous weight rather than store NaN.

        # 3) local dual update
        local_test = self.model_eval(self.train_dataloader, CFG=True)
        self.local_fair_confusion_matrix = local_test['local_confusion']
        self._update_local_mu(local_test)

        # 4) personal model first (global frozen), then global model (personal frozen)
        M_tensor = torch.stack([self.M_a_k[aa] for aa in self.Alabel], dim=0).to(self.device)  # (a, m, m)
        self._calibrated_steps(self.personal_model, self.personal_optimizer, self.model, M_tensor)
        self._calibrated_steps(self.model, self.optimizer, self.personal_model, M_tensor)

        self.local_model = self.get_model_params()

        train_stats = local_test
        param_dict = {'norm': torch.norm(self.local_model).item(),
                      'max': self.local_model.max().item(),
                      'min': self.local_model.min().item()}

        return_dict = {'loss': train_stats['loss'] / train_stats['num'],
                       'acc': train_stats['acc'] / train_stats['num']}
        return_dict.update(param_dict)

        end_time = time.time()
        stats = {'id': self.cid, 'time': round(end_time - begin_time, 2)}
        stats.update(return_dict)
        return (len(self.train_data), self.local_model), stats

    def model_eval(self, data, local_fair=True, w=None, CFG=False):
        """Evaluate ``self.model`` (or the global/personal ensemble) on ``data``.

        Args:
            data: A ``DataLoader``, or a dataset that is wrapped in a
                non-shuffling ``DataLoader`` with ``self.batch_size``.
            local_fair: When True, also compute per-sensitive-value confusion
                counts and the local fairness gaps.
            w: Optional weight tensor used to build the loss instead of
                ``self.criterion`` (per-class for the multiclass loss).
            CFG: When True, predict with
                ``w_g * softmax(model(x)) + (1 - w_g) * softmax(personal_model(x))``
                where ``w_g = self.ensemble_weight_global``.

        Returns:
            Dict with summed ``loss``, correct-prediction count ``acc`` and
            sample count ``num``; with ``local_fair`` also ``local_fair_measure``,
            ``local_confusion`` (counts indexed ``[a][y][pred_y]``) and
            ``local_fair_org`` (signed gaps).
        """
        if isinstance(data, DataLoader):
            dataLoader = data
        else:
            dataLoader = DataLoader(data, batch_size=self.batch_size, shuffle=False)

        self.model.eval()
        if CFG == True:
            self.personal_model.eval()
        test_loss = test_acc = test_num = 0.0
        A_labels = list(self.data_info['A_info'].keys())
        Y_labels = list(self.data_info['Y_info'].keys())
        A_sample_num = {a: 0 for a in A_labels}
        # local_fair_info[a][y][pred_y]: number of samples with A=a, Y=y predicted as pred_y.
        local_fair_info = {A_label: {Y_label: {pred_y_label: 0 for pred_y_label in Y_labels} for Y_label in Y_labels} for A_label in A_labels}

        with torch.no_grad():
            for batch_data in dataLoader:
                if self.sensitive_attr:
                    (x, y, A) = batch_data
                else:
                    (x, y) = batch_data
                if self.gpu:
                    x, y, A = x.to(self.device), y.to(self.device), A.to(self.device)
                # Column vectors on every device (previously only on GPU), so the
                # per-group masks below broadcast as (n, 1) instead of (n, n).
                y = y.reshape(-1, 1)
                A = A.reshape(-1, 1)

                if CFG == True:
                    pred_global = F.softmax(self.model(x), dim=1)
                    pred_local = F.softmax(self.personal_model(x), dim=1)
                    # NOTE(review): ``pred`` holds probabilities; if self.criterion is
                    # nn.CrossEntropyLoss it re-applies log_softmax, so the reported
                    # ensemble loss is not the ensemble's cross-entropy -- confirm intent.
                    pred = pred_global * self.ensemble_weight_global + (1 - self.ensemble_weight_global) * pred_local
                else:
                    pred = self.model(x)

                if self.mission == 'binary':
                    criterion = self.criterion if w is None else nn.BCELoss(weight=w.clone().reshape(-1, 1).to(self.device))
                    loss = criterion(pred, y)
                    predicted = (torch.sign(pred - 0.5) + 1) / 2
                    correct = predicted.eq(y).sum().item()
                elif self.mission == 'multiclass':
                    # CrossEntropyLoss requires a 1-D weight with one entry per class.
                    criterion = self.criterion if w is None else nn.CrossEntropyLoss(weight=w.clone().reshape(-1).to(self.device))
                    loss = criterion(pred, y.view(-1).long())
                    _, predicted = torch.max(pred, 1)
                    correct = predicted.eq(y.view(-1).long()).sum().item()

                batch_size = y.size(0)
                test_loss += loss.item() * batch_size  # total loss, not average
                test_acc += correct  # total correct count, not average
                test_num += batch_size

                # (n, 1) for both missions; unsqueeze(1) made binary predictions (n, 1, 1).
                predicted_col = predicted.reshape(-1, 1)
                for a_temp in A_labels:
                    in_group = (A == a_temp)
                    A_sample_num[a_temp] += torch.sum(in_group).cpu().detach()
                    for y_temp in Y_labels:
                        in_cell = (y == y_temp) * in_group
                        for pred_y in Y_labels:
                            local_fair_info[a_temp][y_temp][pred_y] += torch.sum((predicted_col == pred_y) * in_cell).cpu().detach()

        test_dict = {'loss': test_loss, 'acc': test_acc, 'num': test_num}
        if local_fair == True:
            if self.options['fairness_type'] == 'groupwise':
                if "DP" in self.fairness_measure and "EO" in self.fairness_measure:
                    assert len(A_labels) == 2
                    assert len(Y_labels) == 2

                    # Signed as (A=1 rate) - (A=0 rate).
                    local_DP_org = - (local_fair_info[0][0][1] + local_fair_info[0][1][1]) / A_sample_num[0] + (local_fair_info[1][0][1] + local_fair_info[1][1][1]) / A_sample_num[1]
                    local_DP = torch.abs((local_fair_info[0][0][1] + local_fair_info[0][1][1]) / A_sample_num[0] - (local_fair_info[1][0][1] + local_fair_info[1][1][1]) / A_sample_num[1])

                    # Signed as (A=1 rate) - (A=0 rate), separately for Y=1 and Y=0.
                    local_EO_Y1 = - (local_fair_info[0][1][1] / (local_fair_info[0][1][0] + local_fair_info[0][1][1]) - local_fair_info[1][1][1] / (local_fair_info[1][1][0] + local_fair_info[1][1][1]))
                    local_EO_Y0 = - (local_fair_info[0][0][1] / (local_fair_info[0][0][0] + local_fair_info[0][0][1]) - local_fair_info[1][0][1] / (local_fair_info[1][0][0] + local_fair_info[1][0][1]))
                    local_EO_org = [local_EO_Y1, local_EO_Y0]
                    local_EO = torch.max(local_EO_Y1.abs(), local_EO_Y0.abs())

                test_dict.update({'local_fair_measure': {'DP': float(local_DP), 'EO': float(local_EO)}})
                test_dict.update({'local_confusion': local_fair_info})  # C^{a,k}
                test_dict.update({'local_fair_org': {'DP': local_DP_org, 'EO': local_EO_org}})

            if self.options['fairness_type'] == 'subgroup':
                if "DP" in self.fairness_measure and "EOP" in self.fairness_measure:
                    assert len(A_labels) >= 2
                    assert len(Y_labels) > 2
                    assert test_num == sum(A_sample_num.values())
                    total_pred_Y_A = {a: {y: sum([local_fair_info[a][y_temp][y] for y_temp in Y_labels]) for y in Y_labels} for a in A_labels}
                    total_pred_Y = {y: sum([local_fair_info[a_temp][y_temp][y] for y_temp in Y_labels for a_temp in A_labels]) for y in Y_labels}  # n_y
                    local_DP_org = {a: {y: total_pred_Y_A[a][y] / A_sample_num[a] - total_pred_Y[y] / test_num for y in Y_labels} for a in A_labels}
                    local_DP = torch.max(torch.abs(torch.tensor([[local_DP_org[a][y] for y in Y_labels] for a in A_labels])))

                    total_Y = {y: sum([local_fair_info[a_temp][y][pred_y] for a_temp in A_labels for pred_y in Y_labels]) for y in Y_labels}
                    total_Y_pred_Y = {y: {pred_y: sum([local_fair_info[a][y][pred_y] for a in A_labels]) for pred_y in Y_labels} for y in Y_labels}
                    total_Y_pred_Y_A = {a: {y: {pred_y: local_fair_info[a][y][pred_y] for pred_y in Y_labels} for y in Y_labels} for a in A_labels}
                    total_Y_A = {a: {y: sum([local_fair_info[a][y][pred_y] for pred_y in Y_labels]) for y in Y_labels} for a in A_labels}
                    local_EOP_org = {a: {y: total_Y_pred_Y_A[a][y][y] / total_Y_A[a][y] - total_Y_pred_Y[y][y] / total_Y[y] for y in Y_labels} for a in A_labels}
                    local_EOP = torch.max(torch.abs(torch.tensor([[local_EOP_org[a][y] for y in Y_labels] for a in A_labels])))

                test_dict.update({'local_fair_measure': {'multi_DP': float(local_DP), 'multi_EOP': float(local_EOP)}})
                test_dict.update({'local_confusion': local_fair_info})  # C^{a,k}
                test_dict.update({'local_fair_org': {'DP': local_DP_org, 'EOP': local_EOP_org}})
        return test_dict

    def local_eval(self, ensemble=False):
        """Evaluate on the train and test loaders; ``ensemble=True`` uses the CFG ensemble."""
        if ensemble == False:
            train_data_test = self.model_eval(self.train_dataloader)
            test_data_test = self.model_eval(self.test_dataloader)
        elif ensemble == True:
            train_data_test = self.model_eval(self.train_dataloader, CFG=True)
            test_data_test = self.model_eval(self.test_dataloader, CFG=True)
        return {'train': train_data_test, 'test': test_data_test}
        