import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
from fedlearn.models.models import choose_model
from torch.utils.data import DataLoader
from fedlearn.models.FairBatchSampler import FairBatch, CustomDataset
from fedlearn.utils.model_utils import weighted_loss
from fedlearn.utils.model_utils import get_sort_idxs, get_cdf, get_sample_target
import torch.nn.functional as F
import copy
import cvxpy as cp
from fedlearn.models.client import Client

class CGFFedClient(Client):

    def __init__(self, cid, train_data, val_data, test_data, options={}, model=None):
        """Initialise dual variables, data statistics and the personal model.

        Args:
            cid: Client id; used as the key into the per-client tables of
                ``self.data_info``.
            train_data: Training dataset, forwarded to ``Client.__init__``.
            val_data: Validation dataset, forwarded to ``Client.__init__``.
            test_data: Test dataset, forwarded to ``Client.__init__``.
            options: Configuration mapping. Keys read here are
                ``fairness_constraints`` (with ``fairness_measure``, ``global``
                and ``local``), ``post_local_round_mu``, ``fairness_type``,
                ``ensemble_params_lr``, ``dual_params_lr``,
                ``dual_params_bound`` and ``local_optimizer``.
            model: Model whose deep copy becomes ``self.personal_model``.

        Note:
            ``self.options``, ``self.data_info``, ``self.local_lr`` and
            ``self.wd`` are read but never assigned here; presumably the base
            ``Client`` constructor provides them -- TODO confirm.
        """
        super().__init__(cid, train_data, val_data, test_data, options, model)
        self.fairness_constraints = options['fairness_constraints']
        self.fair_metric = self.fairness_constraints['fairness_measure']
        self.global_fair_constraint = self.fairness_constraints['global']
        self.local_fair_constraint = self.fairness_constraints['local']
        self.post_local_round_mu = options['post_local_round_mu']

        self.Alabel = [int(a) for a in self.data_info['Alabel']]
        self.Ylabel = [int(y) for y in self.data_info['Ylabel']]

        # Dual variables: lambda for the global constraint, mu for the local one.
        fairness_type = self.options['fairness_type']
        if fairness_type == 'groupwise':
            if self.fair_metric == 'DP':
                self.global_lamb = torch.tensor([0.01, 0.01], requires_grad=False)
                self.local_mu = torch.tensor([0.01, 0.01], requires_grad=False)
            if self.fair_metric == 'EO':
                # first row for Y=1, second for Y=0
                self.global_lamb = torch.tensor([[0.1, 0.1],
                                                 [0.1, 0.1]], requires_grad=False)
                self.local_mu = torch.tensor([[0.1, 0.1],
                                              [0.1, 0.1]], requires_grad=False)
        elif fairness_type == 'subgroup':
            # One pair of (A x Y)-shaped duals per constraint; only the initial
            # value differs between the two supported metrics.
            initial_value = {'DP': 0.1, 'EOP': 0.01}.get(self.fair_metric)
            if initial_value is not None:
                n_a, n_y = len(self.Alabel), len(self.Ylabel)
                self.global_lamb_1 = torch.ones(n_a, n_y, requires_grad=False) * initial_value
                self.global_lamb_2 = torch.ones(n_a, n_y, requires_grad=False) * initial_value
                self.local_mu_1 = torch.ones(n_a, n_y, requires_grad=False) * initial_value
                self.local_mu_2 = torch.ones(n_a, n_y, requires_grad=False) * initial_value

        self.ensemble_weight_global = 0.5
        self.train_A_info = self.data_info['train_A_info']

        # Share of all training samples (over every client) held by this client.
        client_A_counts = self.data_info['train_client_A_info']
        own_count = sum([client_A_counts[self.cid][a] for a in self.Alabel])
        all_count = sum([client_A_counts[c][a] for a in self.Alabel for c in client_A_counts.keys()])
        self.p_k = torch.from_numpy(own_count / all_count)

        self.class_num = len(list(self.Ylabel))
        self.train_client_A_info = client_A_counts[self.cid]

        train_num = self.data_info['train_num']
        self.p_a_k = {a: self.train_client_A_info[a] / train_num for a in self.Alabel}
        self.p_k_given_a = {a: self.train_client_A_info[a] / self.train_A_info[a] for a in self.Alabel}

        self.train_A_Y_info = self.data_info['train_A_Y_info']
        self.train_client_A_Y_info = self.data_info['train_client_A_Y_info'][self.cid]
        # torch.from_numpy requires ndarray inputs, so the count tables are
        # assumed to hold numpy arrays -- TODO confirm against the data loader.
        self.p_a_y_k = {a: {y: torch.from_numpy(self.train_client_A_Y_info[a][y] / train_num) for y in self.Ylabel} for a in self.Alabel}
        self.p_a_y = {a: {y: torch.from_numpy(self.train_A_Y_info[a][y] / train_num) for y in self.Ylabel} for a in self.Alabel}
        self.p_y_k = {y: torch.from_numpy(sum([self.train_client_A_Y_info[a][y] for a in self.Alabel]) / train_num) for y in self.Ylabel}
        self.p_y = {y: torch.from_numpy(sum([self.train_A_Y_info[a][y] for a in self.Alabel]) / train_num) for y in self.Ylabel}

        print(f'self.p_y_k: {self.p_y_k}')
        print(f'self.p_y: {self.p_y}')

        self.ensemble_params_lr = self.options['ensemble_params_lr']
        self.dual_params_lr = self.options['dual_params_lr']
        self.dual_params_bound = self.options['dual_params_bound']

        # CGFFed model
        self.personal_model = copy.deepcopy(model)
        self.M_a_k = {a: torch.zeros(self.class_num, self.class_num) for a in self.Alabel}

        # Optimizer for the personal model. ``optim.Adagrad`` is the optimizer
        # class; ``optim.adagrad`` is the submodule and is not callable.
        optimizer_classes = {'sgd': optim.SGD, 'adam': optim.Adam, 'adagrad': optim.Adagrad}
        optimizer_cls = optimizer_classes.get(options['local_optimizer'].lower())
        if optimizer_cls is not None:
            self.personal_optimizer = optimizer_cls(self.personal_model.parameters(), lr=self.local_lr, weight_decay=self.wd)

                
    @staticmethod
    def move_model_to_gpu(model, options):
        if 'gpu' in options and (options['gpu'] is True):
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if 'device' not in options else options['device']
            model.to(device)
            print('>>> Use gpu on self.device {}'.format(device.index))
        else:
            print('>>> Do not use gpu')

    def set_params(self, flat_params):
        '''set model parameters, where input is a flat parameter'''
        self.model.set_params(flat_params)

    def get_model_params(self):
        '''get local flat model parameters, transform torch model parameters into flat tensor'''
        return self.model.get_flat_params()
    
    def get_global_params(self, global_params):
        self.global_params = copy.deepcopy(global_params)

        
    def get_grads(self, mini_batch_data):
        '''get model gradient'''
        x, y = mini_batch_data
        self.model.train()
        if self.gpu:
            x, y = x.to(self.device), y.to(self.device)
        self.optimizer.zero_grad()
        pred = self.model(x)
        loss = self.criterion(pred, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 50)
        flat_grads = self.model.get_flat_grads().cpu().detach()
        # self.optimizer.zero_grad()
        return torch.empty_like(flat_grads).copy_(flat_grads), loss.cpu().detach()

    def get_pred(self):
        self.model.eval()
        dataloader = DataLoader(self.train_data, batch_size = self.batch_size, shuffle = False)
        pred_score = predicted = torch.tensor([])
        with torch.no_grad():
            for i, (x, y, a) in enumerate(dataloader):
                if self.gpu:
                    x, y = x.to(self.device), y.to(self.device)
                self.optimizer.zero_grad()
                pred_score_batch = self.model(x).detach().cpu()
                if self.mission == 'binary':
                    predicted_batch = ((torch.sign(pred_score_batch - 0.5) + 1) / 2)
                elif self.mission == 'multiclass':
                    _, predicted_batch = torch.max(pred_score_batch, 1)
                pred_score = torch.cat([pred_score, pred_score_batch.squeeze()])
                predicted = torch.cat([predicted, predicted_batch.squeeze()])
        return pred_score.reshape(-1,1).clone(), predicted.reshape(-1,1).clone(), self.train_data.A
    
    def local_train(self):

        begin_time = time.time()

        for _ in range(self.num_local_round):
            self.model.train()
            (x, y, a) = self.get_next_train_batch()
            if self.gpu:
                x, y = torch.squeeze(x.to(self.device)), y.to(self.device).reshape(-1,1)
            self.optimizer.zero_grad()
            pred = self.model(x)
            loss = self.criterion(pred, y.view(-1).long())
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 60)
            self.optimizer.step() 
        # print(f"\n learnig rate: {self.optimizer.param_groups[0]['lr']}")
        # self.scheduler.step()

        self.local_model = self.get_model_params()

        train_stats = self.model_eval(self.train_dataloader)
        param_dict = {'norm': torch.norm(self.local_model).item(),
            'max': self.local_model.max().item(),
            'min': self.local_model.min().item()}
        
        return_dict = {'loss': train_stats['loss'] / train_stats['num'],
            'acc': train_stats['acc'] / train_stats['num']}
        return_dict.update(param_dict)

        end_time = time.time()
        stats = {'id': self.cid, 'time': round(end_time - begin_time, 2)}
        stats.update(return_dict)
        return (len(self.train_data), self.local_model), stats
    
    def get_cal_Matrix(self, lamb, mu):
        if self.options['fairness_type'] == 'groupwise':
            if self.fair_metric == 'DP':
                assert len(lamb) == 2
                assert len(mu) == 2
                M_a_k = {a : torch.zeros(self.class_num, self.class_num) for a in self.Alabel}

                # base = torch.zeros(self.class_num, self.class_num, dtype=lamb.dtype, device=lamb.device); base[0,1] = base[1,1] = 1

                # D_g_k = {a: (2*a - 1) * self.p_k_given_a[a] * base for a in self.Alabel}
                # D_l_k = {a: (2*a - 1) * base                    for a in self.Alabel}

                D_g_k = {a: (2*a-1)*torch.tensor([[0, self.p_k_given_a[a]],
                                                  [0, self.p_k_given_a[a]]]).clone() for a in self.Alabel}
                D_l_k = {a: (2*a-1)*torch.tensor([[0, 1],
                                                  [0, 1]]).clone() for a in self.Alabel}
                for a in self.Alabel:
                    M_a_k[a] = torch.eye(self.class_num, self.class_num) - 1 / float(self.p_a_k[a]) * (lamb[0] - lamb[1]) *  D_g_k[a] - 1 / float(self.p_a_k[a]) * (mu[0] - mu[1]) *  D_l_k[a]

                return M_a_k
            if self.fair_metric == 'EO':
                assert lamb.shape == (2, 2)
                assert mu.shape == (2, 2)
                M_a_k = {a : torch.zeros(self.class_num, self.class_num) for a in self.Alabel}
                # for Y = 1
                D_g_k_1 = {a: (2*a-1)*torch.tensor([[0, 0                                ],
                                                    [0, self.p_a_k[a] / self.p_a_y[a][1] ]]).clone() for a in self.Alabel}
                # for Y = 0
                D_g_k_2 = {a: (2*a-1)*torch.tensor([[0, self.p_a_k[a] / self.p_a_y[a][0] ],
                                                    [0, 0                                ]]).clone() for a in self.Alabel}
                
                # for Y = 1
                D_l_k_1 = {a: (2*a-1)*torch.tensor([[0, 0                                ],
                                                    [0, self.p_a_k[a] / self.p_a_y_k[a][1] ]]).clone() for a in self.Alabel}
                # for Y = 0
                D_l_k_2 = {a: (2*a-1)*torch.tensor([[0, self.p_a_k[a] / self.p_a_y_k[a][0] ],
                                                    [0, 0                                ]]).clone() for a in self.Alabel}
                for a in self.Alabel:
                    M_a_k[a] = ( torch.eye(self.class_num, self.class_num) - 1 / float(self.p_a_k[a]) * (lamb[0][0] - lamb[0][1]) *  D_g_k_1[a]
                                                                           - 1 / float(self.p_a_k[a]) * (lamb[1][0] - lamb[1][1]) *  D_g_k_2[a]
                                                                           - 1 / float(self.p_a_k[a]) * (mu[0][0] - mu[0][1]) *  D_l_k_1[a] 
                                                                           - 1 / float(self.p_a_k[a]) * (mu[1][0] - mu[1][1]) *  D_l_k_2[a] )
        
                return M_a_k
            
        if self.options['fairness_type'] == 'subgroup':
            assert lamb[0].shape == (len(self.Alabel), len(self.Ylabel)) # lamb[0] = lamb_1, lamb[1] = lamb_2
            assert mu[0].shape == (len(self.Alabel), len(self.Ylabel)) # mu[0] = mu_1, mu[1] = mu_2
            if self.fair_metric == 'DP':
                M_a_k = {a : torch.zeros(self.class_num, self.class_num) for a in self.Alabel}

                D_g_k = {a_prime: {y: {aa: torch.zeros(self.class_num, self.class_num).clone() for aa in self.Alabel} for y in self.Ylabel} for a_prime in self.Alabel}
                for a_prime in self.Alabel:
                    for y in self.Ylabel:
                        for aa in self.Alabel:
                            D_g_k[a_prime][y][aa][:,y] = self.p_k_given_a[aa] - self.p_a_k[aa] if aa == a_prime else - self.p_a_k[aa]

                D_l_k = {a_prime: {y: {aa: torch.zeros(self.class_num, self.class_num).clone() for aa in self.Alabel} for y in self.Ylabel} for a_prime in self.Alabel}

                for a_prime in self.Alabel:
                    for y in self.Ylabel:
                        for aa in self.Alabel:
                            D_l_k[a_prime][y][aa][:,y] = 1 - self.p_k_given_a[aa] if aa == a_prime else - self.p_k_given_a[aa]

                for a in self.Alabel:
                    M_a_k[a] = ( torch.eye(self.class_num, self.class_num) - 1 / float(self.p_a_k[a]) * sum((lamb[0][a_prime, y] - lamb[1][a_prime, y]) * D_g_k[a_prime][y][a] for a_prime in self.Alabel for y in self.Ylabel)
                                                                           - 1 / float(self.p_a_k[a]) * sum((mu[0][a_prime, y]   - mu[1][a_prime, y])   * D_l_k[a_prime][y][a] for a_prime in self.Alabel for y in self.Ylabel))
            
                return M_a_k
            
            elif self.fair_metric == 'EOP':
                M_a_k = {a : torch.zeros(self.class_num, self.class_num) for a in self.Alabel}

                D_g_k = {a_prime: {y: {aa: torch.zeros(self.class_num, self.class_num).clone() for aa in self.Alabel} for y in self.Ylabel} for a_prime in self.Alabel}
                for a_prime in self.Alabel:
                    for y in self.Ylabel:
                        for aa in self.Alabel:
                            D_g_k[a_prime][y][aa][y,y] = self.p_a_k[aa]/self.p_a_y[aa][y] - self.p_a_k[aa]/self.p_y[y] if aa == a_prime else - self.p_a_k[aa]/self.p_y[y]
                
                D_l_k = {a_prime: {y: {aa: torch.zeros(self.class_num, self.class_num).clone() for aa in self.Alabel} for y in self.Ylabel} for a_prime in self.Alabel}

                for a_prime in self.Alabel:
                    for y in self.Ylabel:
                        for aa in self.Alabel:
                            D_l_k[a_prime][y][aa][y,y] = self.p_a_k[aa]/self.p_a_y_k[aa][y] - self.p_a_k[aa]/self.p_y_k[y] if aa == a_prime else - self.p_a_k[aa]/self.p_y_k[y]
                
                for a in self.Alabel:
                    M_a_k[a] = ( torch.eye(self.class_num, self.class_num) - 1 / float(self.p_a_k[a]) * sum((lamb[0][a_prime, y] - lamb[1][a_prime, y]) * D_g_k[a_prime][y][a] for a_prime in self.Alabel for y in self.Ylabel)
                                                                           - 1 / float(self.p_a_k[a]) * sum((mu[0][a_prime, y]   - mu[1][a_prime, y])   * D_l_k[a_prime][y][a] for a_prime in self.Alabel for y in self.Ylabel))
                    
                return M_a_k
                
            
    def get_calibrated_loss(self, model, M):
        dataloader = DataLoader(self.train_data, batch_size = self.batch_size, shuffle = False)
        total_loss = 0
        sample_num = 0
        M_tensor = torch.stack([M[aa] for aa in self.Alabel], dim=0).to(self.device) # (a, m, m)
        model.eval()
        with torch.no_grad():
            for i, (x, y, a) in enumerate(dataloader):
                if self.gpu:
                    x, y, a = x.to(self.device), y.to(self.device), a.to(self.device)
                if self.mission == 'binary':
                    pred = model(x)
                elif self.mission == 'multiclass':
                    pred = model(x)
                    outputs = - F.log_softmax(pred, dim=1)
                n, m = outputs.shape
                M_sel = M_tensor[a.view(-1).long()] # (n, m, m)
                rows = M_sel[torch.arange(n), y.view(-1).long()] # (n, m)
                batch_loss = torch.sum(outputs * rows, dim=1)  # (n,)
                total_loss += batch_loss.sum().detach().cpu().item()
                sample_num += n
        return {'average_loss': total_loss/sample_num, 'total_loss':total_loss,'sample_num':sample_num}
    
    def local_obj_H(self, lamb, mu, beta=0.1):
        if self.options['fairness_type'] == 'groupwise':
            self.M_a_k = self.get_cal_Matrix(lamb, mu)

            M_stack = torch.stack([ self.M_a_k[a].t() for a in self.Alabel ], dim=0) # (A,m,m)
            A = torch.tensor(self.val_data.A, device=M_stack.device) 
            M_i_T = M_stack[A.view(-1).long()]   # (n_k, m, m)

            eta = self.val_score.unsqueeze(-1).double()     # -> (n_k, m, 1)
            logits = torch.bmm(M_i_T, eta).squeeze(-1)  # -> (n_k, m)
            term_1 = logits.max(dim=1).values.mean()    # scalar

        elif self.options['fairness_type'] == 'subgroup':
            self.M_a_k = self.get_cal_Matrix(lamb, mu)

            M_stack = torch.stack([ self.M_a_k[a].t() for a in self.Alabel ], dim=0) # (A,m,m)
            A = torch.tensor(self.val_data.A, device=M_stack.device) 
            M_i_T = M_stack[A.view(-1).long()]   # (n_k, m, m)

            eta = self.val_score.unsqueeze(-1).double()     # -> (n_k, m, 1)
            logits = torch.bmm(M_i_T, eta).squeeze(-1)  # -> (n_k, m)
            softmax_weights = F.softmax(logits / beta, dim=1)
            term_1 = torch.sum(softmax_weights * logits, dim=1).mean()

        loss = term_1 +  self.global_fair_constraint * torch.norm(lamb[0] + lamb[1], p=1)  + self.local_fair_constraint * torch.norm(mu[0] + mu[1], p=1) /  self.p_k
        return loss
    
    def FFACT_post_train(self):
        begin_time = time.time()
        # already set the parameter global_lamb, model parameter
        self.model.eval()
        if self.options['fairness_type'] == 'groupwise':
            self.global_lamb.requires_grad_(False)
            self.local_mu_optim = self.local_mu.clone().detach()
            self.local_mu_optim.requires_grad_(True)
            # self.local_mu_optim.grad.zero_()
        
            for _ in range(self.post_local_round_mu):
                loss = self.local_obj_H(self.global_lamb,self.local_mu_optim)
                loss.backward()
                with torch.no_grad():
                    self.local_mu_optim -= self.dual_params_lr * self.local_mu_optim.grad
                    self.local_mu_optim = torch.clamp(self.local_mu_optim, min=0.0, max=self.dual_params_bound).requires_grad_(True)
            self.local_mu = self.local_mu_optim.detach().clone().requires_grad_(False)

            self.global_lamb_optim = self.global_lamb.clone().detach()
            self.global_lamb_optim.requires_grad_(True)
            loss = self.local_obj_H( self.global_lamb_optim, self.local_mu)
            loss.backward()
            with torch.no_grad(): 
                self.global_lamb_optim -= self.dual_params_lr * 5 * self.global_lamb_optim.grad

            return self.global_lamb_optim.detach().clone(), self.val_data.X.shape[0]
        elif self.options['fairness_type'] == 'subgroup':
            self.global_lamb_1.requires_grad_(False)
            self.global_lamb_2.requires_grad_(False)
            self.local_mu_optim_1 = self.local_mu_1.clone().detach()
            self.local_mu_optim_1.requires_grad_(True)
            self.local_mu_optim_2 = self.local_mu_2.clone().detach()
            self.local_mu_optim_2.requires_grad_(True)
            # self.local_mu_optim.grad.zero_()
        
            for _ in range(self.post_local_round_mu):
                loss = self.local_obj_H([self.global_lamb_1,self.global_lamb_2],[self.local_mu_optim_1,self.local_mu_optim_2])
                loss.backward()
                with torch.no_grad():
                    self.local_mu_optim_1 -= self.dual_params_lr * self.local_mu_optim_1.grad
                    self.local_mu_optim_1 = torch.clamp(self.local_mu_optim_1, min=0.0, max=self.dual_params_bound).requires_grad_(True)
                    self.local_mu_optim_2 -= self.dual_params_lr * self.local_mu_optim_2.grad
                    self.local_mu_optim_2 = torch.clamp(self.local_mu_optim_2, min=0.0, max=self.dual_params_bound).requires_grad_(True)
            self.local_mu_1 = self.local_mu_optim_1.detach().clone().requires_grad_(False)
            self.local_mu_2 = self.local_mu_optim_2.detach().clone().requires_grad_(False)

            self.global_lamb_optim_1 = self.global_lamb_1.clone().detach()
            self.global_lamb_optim_1.requires_grad_(True)
            self.global_lamb_optim_2 = self.global_lamb_2.clone().detach()
            self.global_lamb_optim_2.requires_grad_(True)
            loss = self.local_obj_H( [self.global_lamb_optim_1, self.global_lamb_optim_2],[self.local_mu_1, self.local_mu_2])
            loss.backward()
            with torch.no_grad(): 
                self.global_lamb_optim_1 -= self.dual_params_lr  * 5 * self.global_lamb_optim_1.grad
                self.global_lamb_optim_2 -= self.dual_params_lr  * 5 * self.global_lamb_optim_2.grad

            return self.global_lamb_optim_1.detach().clone(), self.global_lamb_optim_2.detach().clone(), self.val_data.X.shape[0]

    def model_eval(self, data, local_fair=True, w=None, CFG=False):
        if isinstance(data, DataLoader):
            dataLoader = data
        else: 
            dataLoader = DataLoader(data, batch_size = self.batch_size, shuffle = False)

        self.model.eval()
        test_loss = test_acc = test_num = 0.0
        A_labels = list(self.data_info['A_info'].keys())
        Y_labels = list(self.data_info['Y_info'].keys())
        A_sample_num = {a:0 for a in A_labels}
        local_fair_info = {A_label: {Y_label: {pred_y_label:0 for pred_y_label in Y_labels} for Y_label in Y_labels} for A_label in A_labels}
        
        with torch.no_grad():
            for batch_data in dataLoader:
                if self.sensitive_attr:
                    (x, y, A) = batch_data
                else:
                    (x, y) = batch_data
                if self.gpu:
                    x, y, A = x.to(self.device), y.to(self.device).reshape(-1,1), A.to(self.device).reshape(-1,1)
                
                if CFG == True:
                    self.personal_model.eval()
                    pred_global = F.softmax(self.model(x), dim=1)
                    pred_local = F.softmax(self.personal_model(x), dim=1)
                    pred = pred_global * self.ensemble_weight_global + (1 - self.ensemble_weight_global) * pred_local

                    if self.mission == 'binary':
                        criterion = self.criterion if w is None else nn.BCELoss(weight=w.clone().reshape(-1,1).to(self.device))
                        loss = criterion(pred, y)
                        predicted = (torch.sign(pred - 0.5) + 1) / 2
                        correct = predicted.eq(y).sum().item()
                    elif self.mission == 'multiclass':
                        criterion = self.criterion if w is None else nn.CrossEntropyLoss(weight=w.clone().reshape(-1,1).to(self.device))
                        loss = criterion(pred, y.view(-1).long())
                        _, predicted = torch.max(pred, 1)
                        correct = predicted.eq(y.view(-1).long()).sum().item()
                    
                else:
                    pred = self.model(x)

                    if self.mission == 'binary':
                        criterion = self.criterion if w is None else nn.BCELoss(weight=w.clone().reshape(-1,1).to(self.device))
                        loss = criterion(pred, y)
                        predicted = (torch.sign(pred - 0.5) + 1) / 2
                        correct = predicted.eq(y).sum().item()
                    elif self.mission == 'multiclass':
                        criterion = self.criterion if w is None else nn.CrossEntropyLoss(weight=w.clone().reshape(-1,1).to(self.device))
                        loss = criterion(pred, y.view(-1).long())
                        _, predicted = torch.max(pred, 1)
                        correct = predicted.eq(y.view(-1).long()).sum().item()
                
                batch_size = y.size(0)

                test_loss += loss.item() * y.size(0) # total loss, not average
                test_acc += correct # total acc, not average
                test_num += batch_size 
                for a_temp in A_labels:
                    A_sample_num[a_temp] +=torch.sum((A==a_temp)).cpu().detach()
                    for y_temp in Y_labels:
                        for pred_y in Y_labels:
                            local_fair_info[a_temp][y_temp][pred_y] += torch.sum((predicted.unsqueeze(1) == pred_y) * (y==y_temp) * (A==a_temp)).cpu().detach()

        # total = sum(v for d1 in local_fair_info.values() for d2 in d1.values() for v in d2.values())
        # print(f"compare:{total==test_num}")
        # print(f"local_fair_info:{local_fair_info}")

        test_dict = {'loss': test_loss, 'acc': test_acc, 'num': test_num}
        if local_fair == True:
            if self.options['fairness_type'] == 'groupwise':
                if "DP" in self.fairness_measure and "EO" in self.fairness_measure:
                    assert len(A_labels) == 2
                    assert len(Y_labels) == 2

                    # negative to alignment A1 - A0 item
                    local_DP_org = - (local_fair_info[0][0][1] + local_fair_info[0][1][1]) / A_sample_num[0] + (local_fair_info[1][0][1] + local_fair_info[1][1][1]) / A_sample_num[1]
                    local_DP = torch.abs((local_fair_info[0][0][1] + local_fair_info[0][1][1]) / A_sample_num[0] - (local_fair_info[1][0][1] + local_fair_info[1][1][1]) / A_sample_num[1])
                    
                    # negative to alignment A1 - A0 item
                    local_EO_Y1 = - (local_fair_info[0][1][1] / (local_fair_info[0][1][0] + local_fair_info[0][1][1]) - local_fair_info[1][1][1] / (local_fair_info[1][1][0] + local_fair_info[1][1][1]))
                    local_EO_Y0 = - (local_fair_info[0][0][1] / (local_fair_info[0][0][0] + local_fair_info[0][0][1]) - local_fair_info[1][0][1] / (local_fair_info[1][0][0] + local_fair_info[1][0][1]))
                    local_EO_org = [local_EO_Y0,local_EO_Y1]
                    local_EO = torch.max(local_EO_Y1.abs(),local_EO_Y0.abs())
                # elif self.mission == 'multiclass':

                # DP_disp = torch.sum(self.val_score[self.val_data.A==0] >= post_threshold[0]) / self.N_0_c - torch.sum(self.val_score[self.val_data.A==1] >= post_threshold[1]) / self.N_1_c
                # if val==True:
                test_dict.update({'local_fair_measure':{'DP':float(local_DP),'EO':float(local_EO)}})
                test_dict.update({'local_confusion':local_fair_info})  # C^{a,k}
                test_dict.update({'local_fair_org':{'DP':local_DP_org, 'EO': local_EO_org}})

            if self.options['fairness_type'] == 'subgroup':
                if "DP" in self.fairness_measure and "EOP" in self.fairness_measure:
                    assert len(A_labels) >= 2
                    assert len(Y_labels) > 2
                    assert test_num == sum(A_sample_num.values())
                    total_pred_Y_A = {a:{y: sum([local_fair_info[a][y_temp][y] for y_temp in Y_labels]) for y in Y_labels} for a in A_labels}
                    # total_pred_Y = {y: sum(v for d1 in local_fair_info.values() for d2 in d1.values() for p, v in d2.items() if p == y) for y in Y_labels} #n_y
                    total_pred_Y = {y: sum([local_fair_info[a_temp][y_temp][y] for y_temp in Y_labels for a_temp in A_labels]) for y in Y_labels} #n_y
                    local_DP_org = {a: {y: total_pred_Y_A[a][y] / A_sample_num[a] - total_pred_Y[y]/test_num for y in Y_labels} for a in A_labels }
                    local_DP = torch.max(torch.abs(torch.tensor([[local_DP_org[a][y] for y in Y_labels] for a in A_labels])))

                    total_Y = {y: sum([local_fair_info[a_temp][y][pred_y] for a_temp in A_labels for pred_y in Y_labels]) for y in Y_labels}
                    total_Y_pred_Y = {y: {pred_y: sum([local_fair_info[a][y][pred_y] for a in A_labels])  for pred_y in Y_labels} for y in Y_labels}
                    total_Y_pred_Y_A = {a:{y: {pred_y: local_fair_info[a][y][pred_y] for pred_y in Y_labels} for y in Y_labels } for a in A_labels}
                    total_Y_A = {a:{y: sum([local_fair_info[a][y][pred_y] for pred_y in Y_labels]) for y in Y_labels} for a in A_labels}
                    local_EOP_org = {a: {y: total_Y_pred_Y_A[a][y][y] / total_Y_A[a][y] - total_Y_pred_Y[y][y]/total_Y[y] for y in Y_labels} for a in A_labels }
                    local_EOP = torch.max(torch.abs(torch.tensor([[local_EOP_org[a][y] for y in Y_labels] for a in A_labels])))
                
                test_dict.update({'local_fair_measure':{'multi_DP':float(local_DP),'multi_EOP':float(local_EOP)}})
                test_dict.update({'local_confusion':local_fair_info})  # C^{a,k}
                test_dict.update({'local_fair_org':{'DP':local_DP_org, 'EOP': local_EOP_org}})
                # print(test_dict)
        return test_dict
    
    def get_eval_dict(self, test_dict, local_fair_info, A_sample_num):
        """Attach local fairness statistics to an evaluation dictionary.

        Args:
            test_dict: Evaluation dictionary to extend. It is updated in place
                and also returned.
            local_fair_info: Nested confusion counts indexed as
                ``local_fair_info[a][y][pred_y]`` (sensitive label, true
                label, predicted label). The groupwise branch calls ``.abs()``
                on ratios of these counts, so they must be torch tensors there.
            A_sample_num: Mapping from sensitive label to its sample count.

        Returns:
            ``test_dict``. When the configured fairness type and measures are
            supported, the keys ``'local_fair_measure'``, ``'local_confusion'``
            and ``'local_fair_org'`` are added; otherwise the dictionary is
            returned unchanged.
        """
        A_labels = self.Alabel
        Y_labels = self.Ylabel
        fairness_type = self.options['fairness_type']
        conf = local_fair_info

        if fairness_type == 'groupwise':
            # The statistics are published only inside this check: with an
            # unsupported measure there is nothing computed to publish.
            if "DP" in self.fairness_measure and "EO" in self.fairness_measure:
                assert len(A_labels) == 2
                assert len(Y_labels) == 2

                # Rate of predicting label 1 within each sensitive group.
                # The binary case addresses labels literally as 0 and 1.
                pos_rate = {
                    a: (conf[a][0][1] + conf[a][1][1]) / A_sample_num[a]
                    for a in (0, 1)
                }
                # Signed gaps are oriented as "A=1 minus A=0".
                local_DP_org = -pos_rate[0] + pos_rate[1]
                local_DP = torch.abs(pos_rate[0] - pos_rate[1])

                # Rate of predicting label 1 given true label y, per group.
                cond_rate = {
                    a: {y: conf[a][y][1] / (conf[a][y][0] + conf[a][y][1])
                        for y in (0, 1)}
                    for a in (0, 1)
                }
                local_EO_Y1 = -(cond_rate[0][1] - cond_rate[1][1])
                local_EO_Y0 = -(cond_rate[0][0] - cond_rate[1][0])
                local_EO_org = [local_EO_Y1, local_EO_Y0]
                local_EO = torch.max(local_EO_Y1.abs(), local_EO_Y0.abs())

                test_dict.update({'local_fair_measure': {'DP': float(local_DP), 'EO': float(local_EO)}})
                test_dict.update({'local_confusion': local_fair_info})  # C^{a,k}
                test_dict.update({'local_fair_org': {'DP': local_DP_org, 'EO': local_EO_org}})

        if fairness_type == 'subgroup':
            if "DP" in self.fairness_measure and "EOP" in self.fairness_measure:
                assert len(A_labels) >= 2
                assert len(Y_labels) > 2
                test_num = sum(A_sample_num.values())

                # Counts of each predicted label, per group and overall (n_y).
                pred_by_group = {
                    a: {y: sum([conf[a][y_true][y] for y_true in Y_labels]) for y in Y_labels}
                    for a in A_labels
                }
                pred_total = {
                    y: sum([conf[a][y_true][y] for y_true in Y_labels for a in A_labels])
                    for y in Y_labels
                }
                # Group prediction rate minus overall prediction rate.
                local_DP_org = {
                    a: {y: pred_by_group[a][y] / A_sample_num[a] - pred_total[y] / test_num
                        for y in Y_labels}
                    for a in A_labels
                }
                local_DP = torch.max(torch.abs(torch.tensor(
                    [[local_DP_org[a][y] for y in Y_labels] for a in A_labels])))

                # Counts of each true label, per group and overall, plus the
                # overall number of correct predictions for each label.
                true_by_group = {
                    a: {y: sum([conf[a][y][pred_y] for pred_y in Y_labels]) for y in Y_labels}
                    for a in A_labels
                }
                true_total = {
                    y: sum([conf[a][y][pred_y] for a in A_labels for pred_y in Y_labels])
                    for y in Y_labels
                }
                hit_total = {y: sum([conf[a][y][y] for a in A_labels]) for y in Y_labels}
                # Group per-label accuracy minus overall per-label accuracy.
                local_EOP_org = {
                    a: {y: conf[a][y][y] / true_by_group[a][y] - hit_total[y] / true_total[y]
                        for y in Y_labels}
                    for a in A_labels
                }
                local_EOP = torch.max(torch.abs(torch.tensor(
                    [[local_EOP_org[a][y] for y in Y_labels] for a in A_labels])))

                test_dict.update({'local_fair_measure': {'multi_DP': float(local_DP), 'multi_EOP': float(local_EOP)}})
                test_dict.update({'local_confusion': local_fair_info})  # C^{a,k}
                test_dict.update({'local_fair_org': {'DP': local_DP_org, 'EOP': local_EOP_org}})
        return test_dict
    
    def local_post_eval(self):
        """Refine the local dual variables, then evaluate the calibrated predictor.

        The model is put in eval mode. The local dual variables are updated for
        ``self.post_local_round_mu`` rounds: each round back-propagates
        ``self.local_obj_H``, takes a gradient step of size
        ``self.dual_params_lr`` and clamps the result to
        ``[0, self.dual_params_bound]``. The per-group matrices returned by
        ``self.get_cal_Matrix`` are then applied to the cached
        ``self.val_score`` / ``self.test_score`` and the arg-max predictions
        are summarised through ``self.get_eval_dict``.

        Side effects: overwrites ``self.local_mu`` (or ``self.local_mu_1`` /
        ``self.local_mu_2`` for the subgroup setting), the matching
        ``local_mu_optim*`` attributes and ``self.M_a_k``.

        Returns:
            ``{'val': ..., 'test': ...}``, each value being the dictionary
            produced by ``get_eval_dict`` on top of ``{'acc': <number of
            correct predictions>, 'num': <dataset length>}``.
        """
        self.model.eval()
        fairness_type = self.options['fairness_type']

        def _clamped_step(mu):
            # One gradient step on `mu`, clamped to the allowed range. The
            # clamp creates a new tensor, so the returned leaf has no stale grad.
            with torch.no_grad():
                mu -= self.dual_params_lr * mu.grad
                return torch.clamp(mu, min=0.0, max=self.dual_params_bound).requires_grad_(True)

        if fairness_type == 'groupwise':
            self.global_lamb.requires_grad_(False)
            self.local_mu_optim = self.local_mu.clone().detach().requires_grad_(True)

            for _ in range(self.post_local_round_mu):
                objective = self.local_obj_H(self.global_lamb, self.local_mu_optim)
                objective.backward()
                self.local_mu_optim = _clamped_step(self.local_mu_optim)
            self.local_mu = self.local_mu_optim.detach().clone().requires_grad_(False)

            self.M_a_k = self.get_cal_Matrix(self.global_lamb, self.local_mu)

        elif fairness_type == 'subgroup':
            self.global_lamb_1.requires_grad_(False)
            self.global_lamb_2.requires_grad_(False)
            self.local_mu_optim_1 = self.local_mu_1.clone().detach().requires_grad_(True)
            self.local_mu_optim_2 = self.local_mu_2.clone().detach().requires_grad_(True)

            for _ in range(self.post_local_round_mu):
                objective = self.local_obj_H(
                    [self.global_lamb_1, self.global_lamb_2],
                    [self.local_mu_optim_1, self.local_mu_optim_2],
                )
                objective.backward()
                self.local_mu_optim_1 = _clamped_step(self.local_mu_optim_1)
                self.local_mu_optim_2 = _clamped_step(self.local_mu_optim_2)
            self.local_mu_1 = self.local_mu_optim_1.detach().clone().requires_grad_(False)
            self.local_mu_2 = self.local_mu_optim_2.detach().clone().requires_grad_(False)

            self.M_a_k = self.get_cal_Matrix(
                [self.global_lamb_1, self.global_lamb_2],
                [self.local_mu_1, self.local_mu_2],
            )

        # (A, m, m): transposed matrix of every sensitive label, stacked in
        # the order of self.Alabel.
        # NOTE(review): rows are later selected by the raw values of `A`, so
        # this presumably relies on the labels being 0..A-1 in order - confirm.
        M_stack = torch.stack([self.M_a_k[a].t() for a in self.Alabel], dim=0)

        def _evaluate(data, score):
            # Calibrate the cached scores of one split and collect its metrics.
            sensitive = torch.tensor(data.A, device=M_stack.device)
            per_sample_M = M_stack[sensitive.view(-1).long()]       # (n_k, m, m)
            eta = score.unsqueeze(-1).double()                      # (n_k, m, 1)
            calibrated = torch.bmm(per_sample_M, eta).squeeze(-1)   # (n_k, m)
            _, predicted = torch.max(calibrated, 1)

            # 'acc' holds the number of correct predictions, not the average.
            hits = predicted.eq(torch.tensor(data.Y).view(-1).long()).sum().item()
            summary = {'acc': hits, 'num': len(data)}

            group_sizes = {a: 0 for a in self.Alabel}
            confusion = {
                a: {y: {pred_y: 0 for pred_y in self.Ylabel} for y in self.Ylabel}
                for a in self.Alabel
            }
            for a in self.Alabel:
                group_sizes[a] += torch.sum(torch.tensor(data.A == a)).cpu().detach()
                for y in self.Ylabel:
                    for pred_y in self.Ylabel:
                        # Samples of group a with true label y predicted as pred_y.
                        confusion[a][y][pred_y] += torch.sum(torch.tensor((predicted.unsqueeze(1) == pred_y) * (data.Y == y) * (data.A == a))).clone().detach().cpu()

            return self.get_eval_dict(summary, confusion, group_sizes)

        val_dict = _evaluate(self.val_data, self.val_score)
        test_dict = _evaluate(self.test_data, self.test_score)

        return {'val': val_dict, 'test': test_dict}


    def get_val_test_score(self):
        """Cache the model's softmax scores on the validation and test sets.

        Each dataset is read in order (no shuffling) in batches of
        ``self.batch_size`` with the model in eval mode and gradients disabled.
        The softmax over dim 1 of the model output is stored on CPU in
        ``self.val_score`` and ``self.test_score``.

        Returns:
            ``self.val_score``; the test scores are only stored on the instance.
        """
        def _score_dataset(dataset):
            # Run the model over one dataset and return its score matrix.
            loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

            # Float zero buffer shaped like the one-hot encoding of the labels.
            label_ids = torch.tensor(dataset.Y, dtype=torch.long).view(-1)
            one_hot = F.one_hot(label_ids, num_classes=self.class_num)
            scores = torch.zeros_like(one_hot, requires_grad=False).float()

            self.model.eval()
            offset = 0
            with torch.no_grad():
                for x, y, a in loader:
                    if self.gpu:
                        x, y = torch.squeeze(x.to(self.device)), y.to(self.device)
                    logits = self.model(x).detach().clone().cpu()
                    batch_scores = F.softmax(logits, dim=1)
                    rows = batch_scores.size(0)
                    # Batches arrive in dataset order, so rows fill top to bottom.
                    scores[offset:(offset + rows), :] = batch_scores
                    offset += rows

            assert len(scores) == len(dataset)
            return scores.detach().clone().cpu()

        self.val_score = _score_dataset(self.val_data)
        self.test_score = _score_dataset(self.test_data)

        return self.val_score

    def local_eval(self, ensemble=False):
        """Evaluate the model on this client's train and test loaders.

        Args:
            ensemble: When truthy, ``model_eval`` is called with ``CFG=True``;
                otherwise it is called with its default arguments.

        Returns:
            A dict with keys ``'train'`` and ``'test'`` holding whatever
            ``model_eval`` returns for ``self.train_dataloader`` and
            ``self.test_dataloader`` respectively.
        """
        # Branch on truthiness so every value of `ensemble` selects a path.
        # Comparing against False and then True left both results unbound for
        # values equal to neither (e.g. None). The non-ensemble path still
        # omits CFG so model_eval keeps its own default.
        extra = {'CFG': True} if ensemble else {}
        return {
            'train': self.model_eval(self.train_dataloader, **extra),
            'test': self.model_eval(self.test_dataloader, **extra),
        }

        