import numpy as np
import copy
import os
import functools
import sys

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.functional import cross_entropy

from utils.torch_utils import numpy_to_torch, torch_to_numpy

from copy import deepcopy
from sklearn.preprocessing import normalize
import argparse
from utils.model_utils import read_data, batch_data
from PIL import Image

# Directory that load_image joins image file names onto.
IMAGES_DIR = "data/celeba/data/raw/img_align_celeba"
# Side length (pixels) that load_image resizes every image to. CNN's linear
# head assumes 84: four 2x max-pools take 84 -> 42 -> 21 -> 10 -> 5.
IMAGE_SIZE = 84

class CNN(torch.nn.Module):
    """Four-block convolutional classifier for 84x84 RGB images.

    Each block is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> max_pool2d(2)``.
    Four halvings reduce 84x84 to 5x5, which feeds a linear head of size
    ``32 * 5 * 5 -> output_dim``.

    Args:
        output_dim: number of output logits (default 2).
    """

    def __init__(self, output_dim=2):
        super().__init__()
        # Registration order fixes the order of parameters(), which the flat
        # parameter vectors built by torch_to_numpy rely on.
        self.conv1 = torch.nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 32, 3, padding=1)
        self.conv3 = torch.nn.Conv2d(32, 32, 3, padding=1)
        self.conv4 = torch.nn.Conv2d(32, 32, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(32)
        self.bn2 = torch.nn.BatchNorm2d(32)
        self.bn3 = torch.nn.BatchNorm2d(32)
        self.bn4 = torch.nn.BatchNorm2d(32)
        self.linear = torch.nn.Linear(32 * 5 * 5, output_dim)

    def trainable_parameters(self):
        """Return the parameters whose ``requires_grad`` flag is set."""
        return [p for p in self.parameters() if p.requires_grad]

    def forward(self, x):
        """Return logits of shape ``(batch, output_dim)`` for a 4-D image batch.

        ``x`` is cast to float32 first. Inputs whose feature map after pooling is
        not 32x5x5 (i.e. not 84x84 images) raise RuntimeError in the linear
        layer, instead of being silently reshaped into a wrong batch size.
        """
        # NOTE(review): no ReLU between blocks (an earlier commented-out
        # version applied F.relu); max pooling is the only nonlinearity.
        # Kept as-is to preserve trained behaviour -- confirm intent.
        x = F.max_pool2d(self.bn1(self.conv1(x.float())), 2)
        x = F.max_pool2d(self.bn2(self.conv2(x)), 2)
        x = F.max_pool2d(self.bn3(self.conv3(x)), 2)
        x = F.max_pool2d(self.bn4(self.conv4(x)), 2)
        # flatten(1) keeps the batch dimension. For 84x84 input it equals the
        # old reshape(-1, 800); for other sizes it surfaces the mismatch.
        x = torch.flatten(x, 1)
        return self.linear(x)

def partition_data(x, y, a, state, clients=None):
    """Split one state's data 80/20 into train/test and keep rows with ``a <= 2``.

    The split uses ``train_test_split(test_size=0.2, random_state=0)``.

    Args:
        x, y, a: features, labels and protected-group values. The results are
            indexed with boolean masks, so presumably these are NumPy arrays --
            confirm against callers.
        state: key under which the split is stored.
        clients: dict to store the split in. Defaults to the module-level
            ``all_clients`` dict, which must then exist (as before).

    Returns:
        The stored dict with keys ``'x'``, ``'a'``, ``'y'``, ``'x_test'``,
        ``'a_test'`` and ``'y_test'``. Previously the function returned ``None``.
    """
    # train_test_split was never imported at module level, so every call
    # previously failed with NameError.
    from sklearn.model_selection import train_test_split

    X_train, X_test, y_train, y_test, group_train, group_test = train_test_split(
        x, y, a, test_size=0.2, random_state=0)
    # NOTE(review): the original mask names ("index_of_0_1") suggest only
    # groups 0 and 1 were meant to be kept, but ``<= 2`` also keeps group 2.
    # Behaviour preserved; confirm the intended filter.
    keep_train = group_train <= 2
    keep_test = group_test <= 2
    data = {
        'x': X_train[keep_train],
        'a': group_train[keep_train],
        'y': y_train[keep_train],
        'x_test': X_test[keep_test],
        'a_test': group_test[keep_test],
        'y_test': y_test[keep_test],
    }
    if clients is None:
        clients = all_clients
    clients[state] = data
    return data

def get_sizes(lst):
    """Return element counts and flat-vector slice bounds for a list of tensors.

    Args:
        lst: iterable of tensors (e.g. ``net.parameters()``).

    Returns:
        ``(sizes, bounds)``. ``sizes[i]`` is the number of elements in the i-th
        tensor. ``bounds[i] = (start, end)`` is that tensor's half-open slice in
        the concatenation of all tensors flattened in order.
    """
    # numel() also handles 0-dim tensors (size() == ()); the previous
    # functools.reduce over an empty size raised TypeError for them.
    sizes = [w.numel() for w in lst]
    ends = np.cumsum(sizes).tolist()
    starts = [0] + ends[:-1]
    return sizes, list(zip(starts, ends))

def load_image(img_name):
    """Load ``img_name`` from ``IMAGES_DIR``, resized and converted to RGB.

    Returns:
        ``np.array`` of the image resized to ``(IMAGE_SIZE, IMAGE_SIZE)`` and
        converted to mode ``'RGB'`` (three channels).
    """
    # The context manager closes the underlying file. The previous version
    # left it open, which leaks descriptors when loading many images.
    with Image.open(os.path.join(IMAGES_DIR, img_name)) as img:
        rgb = img.resize((IMAGE_SIZE, IMAGE_SIZE)).convert('RGB')
    return np.array(rgb)

def process_x(raw_x_batch):
    """Load each image named in ``raw_x_batch`` and stack them into one array."""
    return np.array(list(map(load_image, raw_x_batch)))

def torch_to_numpy(lst, arr=None):
    """Flatten a sequence of tensors into one 1-D NumPy array.

    Args:
        lst: iterable of tensors, e.g. ``net.parameters()`` or the output of
            ``torch.autograd.grad``.
        arr: optional preallocated 1-D array to fill. Its length must equal the
            total number of elements in ``lst``.

    Returns:
        The filled array: ``arr`` itself when given, otherwise a new
        ``np.zeros`` array.
    """
    tensors = list(lst)
    sizes, bounds = get_sizes(tensors)
    total = sum(sizes)
    if arr is not None:
        assert len(arr) == total
    else:
        arr = np.zeros(total)
    for (start, end), tensor in zip(bounds, tensors):
        arr[start:end] = tensor.data.cpu().numpy().reshape(-1)
    return arr

def numpy_to_torch(arr, net):
    """Copy a flat NumPy vector into ``net``'s parameters in place.

    The vector is moved to the device of ``net``'s first parameter and split
    according to ``get_sizes(net.parameters())``. Each slice is written through
    a flat view of the corresponding parameter's data.

    Returns:
        ``net`` (modified in place).
    """
    target_device = next(net.parameters()).device
    flat = torch.from_numpy(arr).to(target_device)
    sizes, bounds = get_sizes(net.parameters())
    assert len(flat) == sum(sizes)
    for (start, end), param in zip(bounds, net.parameters()):
        param.data.view(-1)[:] = flat[start:end]
    return net

def setup_clients(model_name=None, model=None, validation=False, seed=-1, num_silos=40):
    """Load CelebA user data and merge users into randomly sized silos.

    Reads ``data/celeba/data/train`` and ``data/celeba/data/test`` with
    ``read_data`` and keeps the first 20% (rounded up) of the users. Consecutive
    users are grouped into silos whose sizes are drawn from a Dirichlet
    distribution with integer concentrations in [1, 20]. A silo's train and
    test examples are the concatenation of its users' examples.

    Args:
        model_name, model, validation: unused; kept for interface compatibility.
        seed: seed for the global ``np.random`` state; ``-1`` means 42.
        num_silos: number of silo sizes to draw (previously hard-coded 40).

    Returns:
        dict ``{'train_clients': {...}, 'test_clients': {...}}``. Each inner
        dict maps a silo index to ``{'x': list, 'y': list, 'a': np.ndarray}``.
    """
    clients, groups, train_data, test_data = read_data(
        'data/celeba/data/train', 'data/celeba/data/test')

    np.random.seed(seed if seed != -1 else 42)

    try:
        user_pool = clients['train_users']
    except (TypeError, KeyError, IndexError):
        # read_data may return a plain sequence of user ids rather than a dict.
        user_pool = clients
    train_users = user_pool[:int(np.ceil(0.2 * len(user_pool)))]
    print('------>', len(train_users))

    alphas = np.random.randint(low=1, high=21, size=num_silos)
    fractions = np.random.dirichlet(alphas)
    silo_sizes = [np.ceil(f * len(train_users)) for f in fractions]
    # The last silo takes whatever the rounded-up sizes leave; this can be <= 0.
    silo_sizes[-1] = len(train_users) - sum(silo_sizes[:-1])

    silo_infos = _group_users_into_silos(train_users, silo_sizes)
    return {
        'train_clients': _merge_silo_data(silo_infos, train_data),
        'test_clients': _merge_silo_data(silo_infos, test_data),
    }


def _group_users_into_silos(users, silo_sizes):
    """Split ``users`` in order into consecutive silos of the given sizes.

    Users left over after the last size is reached, or after a size that can
    never be reached (e.g. <= 0), go into one extra trailing silo. The previous
    inline loop dropped them silently.
    """
    silo_infos = {}
    current = []
    for u in users:
        current.append(u)
        silo_idx = len(silo_infos)
        if silo_idx < len(silo_sizes) and len(current) == silo_sizes[silo_idx]:
            silo_infos[silo_idx] = current
            current = []
    if current:
        silo_infos[len(silo_infos)] = current
    return silo_infos


def _merge_silo_data(silo_infos, data):
    """Concatenate the per-user ``x``/``a``/``y`` lists of every silo."""
    merged = {}
    for silo, users in silo_infos.items():
        silo_x, silo_y, silo_a = [], [], []
        for u in users:
            silo_x += data[u]['x']
            silo_a += data[u]['a']
            silo_y += data[u]['y']
        merged[silo] = {'x': silo_x, 'y': silo_y, 'a': np.array(silo_a)}
    return merged

def train(client, model, lr, local_iterations, weighted):
    """Run local SGD on a copy of ``model`` and return the parameter delta.

    The client's data is batched with ``batch_data(client, 50)``, and all
    batches are visited ``local_iterations`` times.

    Args:
        client: client data dict accepted by ``batch_data``.
        model: global model. It is not modified; a deep copy is trained.
        lr: SGD learning rate.
        local_iterations: number of passes over the client's batches.
        weighted: if true, the loss is the sum of the mean cross-entropies of
            the four (group in {0, 1}) x (label in {0, 1}) cells present in
            the batch, divided by 4. Otherwise it is plain cross-entropy.

    Returns:
        ``torch_to_numpy(trained params) - torch_to_numpy(model params)``.
    """
    client_model_copy = deepcopy(model)
    # Ensure BatchNorm runs in training mode even if ``model`` was left in
    # eval mode (``test`` switches models to eval); matches Fair_FedAvg.
    client_model_copy.train()
    optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)
    batched_x, batched_y, batched_a = batch_data(client, 50)
    for _ in range(local_iterations):
        for x, y, a in zip(batched_x, batched_y, batched_a):
            optimizer.zero_grad()
            model_output = client_model_copy(x.float())
            if weighted:
                a_t = torch.as_tensor(a, device=y.device)
                cell_masks = [(a_t == p) & (y == label) for p in (0, 1) for label in (0, 1)]
                # Mean cross-entropy over an empty cell is NaN, which made the
                # whole loss NaN; empty cells are skipped. The /4 is kept.
                cell_losses = [F.cross_entropy(model_output[m], y[m])
                               for m in cell_masks if m.any()]
                # A batch with no populated cell still needs a graph for
                # backward(); its zero loss leaves the weights unchanged.
                loss = sum(cell_losses) / 4. if cell_losses else model_output.sum() * 0.
            else:
                loss = F.cross_entropy(model_output, y)
            loss.backward()
            optimizer.step()
    return torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())


def test(client, model, protected_attributes):
    """Evaluate ``model`` on one client's data, overall and per protected group.

    The data is batched with ``batch_data(client, 1000)``. Side effect:
    ``model`` is switched to eval mode and not switched back.

    Args:
        client: client data dict accepted by ``batch_data``.
        model: network returning per-class logits.
        protected_attributes: group values to report on. There can be at most
            9, because the per-group lists have fixed length 9.

    Returns:
        ``(correct, PP, total, PP_by_group, total_by_group, running_theta)``:
        - ``correct``: number of correct predictions.
        - ``PP``: number of predictions equal to class 1.
        - ``total``: number of examples.
        - ``PP_by_group``: class-1 prediction counts per group.
        - ``total_by_group``: example counts per group.
        - ``running_theta``: summed cross-entropy per group.
    """
    correct = 0
    PP = 0
    total = 0
    PP_by_group = [0] * 9
    total_by_group = [0] * 9
    running_theta = [0] * 9
    batched_x, batched_y, batched_a = batch_data(client, 1000)
    model.eval()
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for x, y, a in zip(batched_x, batched_y, batched_a):
            model_output = model(x.float())
            model_prediction = model_output.argmax(1)
            PP += model_prediction.sum()

            for idx, p in enumerate(protected_attributes):
                mask = a == p
                if len(model_output[mask]) < 1:
                    # Add a float zero on the output's device. Adding
                    # torch.tensor(0) made the slot a CPU long tensor, so a
                    # later in-place ``+=`` of a float loss raised RuntimeError.
                    running_theta[idx] += model_output.new_zeros(())
                    continue
                running_theta[idx] += F.cross_entropy(
                    model_output[mask], y[mask], reduction='sum')
                PP_by_group[idx] += model_prediction[mask].sum()
                total_by_group[idx] += len(model_prediction[mask])
            correct += model_prediction.eq(y).sum()
            total += len(y)
    return correct, PP, total, PP_by_group, total_by_group, running_theta



def FedAvg(epochs, client_data, model, lr, local_iterations, weighted=True):
    """Train ``model`` with federated averaging and report final test metrics.

    In each epoch every training client runs ``train`` starting from the
    current global model. The mean of the returned parameter deltas is added
    to the global parameters. The learning rate is halved after epochs 19, 39,
    and so on. After the last epoch the function prints:
    - test accuracy;
    - per-group "TP loss": summed cross-entropy on label-1 test examples of
      the group, divided by that group's label-1 test count;
    - per-group "FP loss": the same for label 0.

    Args:
        epochs: number of communication rounds.
        client_data: dict with ``'train_clients'`` and ``'test_clients'``. Each
            maps a client id to a dict with ``'a'`` and ``'y'`` plus whatever
            ``batch_data`` needs.
        model: global model; updated in place.
        lr: initial client learning rate.
        local_iterations: passes over each client's data per round.
        weighted: forwarded to ``train`` (previously hard-coded to True).

    Returns:
        The trained ``model``.
    """
    sys.stdout.flush()
    protected_attributes = [0, 1]
    train_clients = list(client_data['train_clients'].values())
    test_clients = list(client_data['test_clients'].values())
    num_pos_instances = _count_group_label(train_clients, protected_attributes, 1)
    num_neg_instances = _count_group_label(train_clients, protected_attributes, 0)
    num_test_pos_instances = _count_group_label(test_clients, protected_attributes, 1)
    num_test_neg_instances = _count_group_label(test_clients, protected_attributes, 0)
    print(num_pos_instances, num_neg_instances, num_test_pos_instances, num_test_neg_instances)

    for i in range(epochs):
        sys.stdout.flush()
        if i % 20 == 19:
            lr /= 2
        all_client_updates = 0
        for client in train_clients:
            all_client_updates += train(client, model, lr, local_iterations, weighted)
        all_client_updates /= len(train_clients)
        w = torch_to_numpy(model.parameters())
        w += all_client_updates
        numpy_to_torch(w, model)

        if i == epochs - 1:
            _fedavg_report_test(model, test_clients, protected_attributes,
                                num_test_pos_instances, num_test_neg_instances)
    return model


def _count_group_label(clients, protected_attributes, label):
    """Count examples with ``a == p`` and ``y == label`` for each p, summed over clients."""
    counts = np.zeros(len(protected_attributes), dtype=np.int64)
    for client in clients:
        a = np.asarray(client['a'])
        y = np.asarray(client['y'])
        counts += np.array([((a == p) & (y == label)).sum() for p in protected_attributes])
    return counts


def _fedavg_report_test(model, test_clients, protected_attributes,
                        num_test_pos_instances, num_test_neg_instances):
    """Print test accuracy and per-group TP/FP losses of ``model``; return the accuracy."""
    print('Printing test accuracy')
    correct = 0
    total = 0
    # Slots idx hold label-1 ("TP") losses and idx + 2 label-0 ("FP") losses.
    running_theta_test = [0] * 9
    with torch.no_grad():
        for client in test_clients:
            batched_x, batched_y, batched_a = batch_data(client, 100)
            for x, y, a in zip(batched_x, batched_y, batched_a):
                model_output = model(x.float())
                model_prediction = model_output.argmax(1)
                a_t = torch.as_tensor(a, device=y.device)
                for offset, label in ((0, 1), (2, 0)):
                    for idx, p in enumerate(protected_attributes):
                        mask = (a_t == p) & (y == label)
                        if not mask.any():
                            # A float zero on the right device keeps later
                            # in-place ``+=`` / ``/=`` valid. torch.tensor(0)
                            # made the slot a CPU long tensor and those ops raised.
                            running_theta_test[idx + offset] += model_output.new_zeros(())
                            continue
                        running_theta_test[idx + offset] += F.cross_entropy(
                            model_output[mask], y[mask], reduction='sum')
                correct += model_prediction.eq(y).sum()
                total += len(y)
    accuracy = correct.item() * 1.0 / total
    print(accuracy)

    for idx in range(len(protected_attributes)):
        running_theta_test[idx] /= num_test_pos_instances[idx]
        running_theta_test[idx + 2] /= num_test_neg_instances[idx]
    print("TP loss: ", running_theta_test[:2])
    print("FP loss: ", running_theta_test[2:])
    return accuracy

def Fair_FedAvg(epochs, client_data, model, protected_attributes, initial_lr, local_iterations, rounds, B, lr_theta, threshold, global_fairness=True):
    """FedAvg with per-group loss penalties whose weights are adapted each round.

    Each of ``rounds`` outer rounds sets the multipliers
    ``lmbda = B * exp(theta) / (1 + sum(exp(theta)))`` and runs ``epochs``
    FedAvg epochs. A client's per-batch loss is the overall cross-entropy plus
    ``sum_p lmbda_p * term_p``. ``term_p`` is the group-p loss on the batch
    minus ``threshold / num_batches``. The group-p loss is:
    - with ``global_fairness``: summed cross-entropy divided by the global
      train count of group p;
    - otherwise: mean cross-entropy.
    Groups absent from a batch contribute nothing.

    After the last epoch of a round, the same terms are re-evaluated on the
    training data with the global model and summed. The sum is divided by the
    group counts, and ``theta`` moves by ``lr_theta`` times that vector.

    Args:
        epochs: FedAvg epochs per round.
        client_data: dict with ``'train_clients'`` / ``'test_clients'``. Each
            client dict has ``'a'`` (compared element-wise to group values)
            plus whatever ``batch_data`` needs.
        model: global model; updated in place.
        protected_attributes: group values of ``a`` to constrain.
        initial_lr: client SGD learning rate (reset at every round).
        local_iterations: passes over each client's batches per epoch.
        rounds: number of outer (multiplier) rounds.
        B: scale of the multipliers.
        lr_theta: step size for ``theta``.
        threshold: constraint slack, spread evenly over a client's batches.
        global_fairness: selects the group-loss normalisation (see above).

    Returns:
        ``(avg_model, test_accs)``.
        - ``avg_model``: a copy of ``model`` whose parameters are the mean of
          the global parameters at the end of each round. Buffers such as
          BatchNorm statistics keep their initial values.
        - ``test_accs``: the test accuracy at the end of each round.
    """
    device = next(model.parameters()).device
    # float64, matching the original torch.tensor(np.zeros(...)).
    theta = torch.zeros(len(protected_attributes), dtype=torch.float64, device=device)
    average_iterate = 0
    iterates = 0
    avg_model = deepcopy(model)
    test_accs = []

    def count_by_group(clients):
        # Per-group example counts summed over clients.
        counts = 0
        for client in clients:
            counts += np.array([(client['a'] == p).sum() for p in protected_attributes])
        return counts

    def group_term(output, target, ins, num_batches):
        # Constraint term for one group on one batch. The threshold is
        # spread evenly across the client's batches.
        if global_fairness:
            return F.cross_entropy(output, target, reduction='sum') / ins - threshold / num_batches
        return F.cross_entropy(output, target) - threshold / num_batches

    num_instances = count_by_group(client_data['train_clients'].values())
    num_test_instances = count_by_group(client_data['test_clients'].values())
    print(num_instances, num_test_instances)

    for k in range(rounds):
        lmbda = B * theta.exp() / (1 + theta.exp().sum())
        print(lmbda)
        grad_theta = torch.zeros(len(protected_attributes), device=device)
        lr = initial_lr
        for i in range(epochs):
            sys.stdout.flush()
            all_client_updates = 0
            for client in client_data['train_clients'].values():
                client_model_copy = deepcopy(model)
                client_model_copy.train()
                optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)
                batched_x, batched_y, batched_a = batch_data(client, 50)
                num_batches = len(batched_a)
                for _ in range(local_iterations):
                    for x, y, a in zip(batched_x, batched_y, batched_a):
                        optimizer.zero_grad()
                        model_output = client_model_copy(x.float())
                        loss = 0
                        for p, l, ins in zip(protected_attributes, lmbda, num_instances):
                            mask = a == p
                            if len(model_output[mask]) * ins * num_batches == 0:
                                continue
                            loss += l * group_term(model_output[mask], y[mask], ins, num_batches)
                        loss += F.cross_entropy(model_output, y)
                        loss.backward()
                        optimizer.step()
                all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())
            all_client_updates /= len(client_data['train_clients'])
            w = torch_to_numpy(model.parameters())
            w += all_client_updates
            numpy_to_torch(w, model)
            print('Finishing epoch ', i)

            if i == epochs - 1:
                # Re-evaluate the constraint terms on the training data.
                with torch.no_grad():
                    for client in client_data['train_clients'].values():
                        batched_x, batched_y, batched_a = batch_data(client, 50)
                        num_batches = len(batched_a)
                        client_theta = torch.zeros(len(protected_attributes), device=device)
                        for x, y, a in zip(batched_x, batched_y, batched_a):
                            model_output = model(x.float())
                            for p_num, (p, ins) in enumerate(zip(protected_attributes, num_instances)):
                                mask = a == p
                                if len(model_output[mask]) * ins * num_batches == 0:
                                    continue
                                client_theta[p_num] += group_term(model_output[mask], y[mask], ins, num_batches)
                        # Add each client's totals once. This used to run inside
                        # the batch loop and re-add the running prefix sum after
                        # every batch, over-counting early batches.
                        grad_theta += client_theta

                iterates += 1
                average_iterate += torch_to_numpy(model.parameters())
                numpy_to_torch(average_iterate / iterates, avg_model)

                print('Printing test accuracy')
                correct = 0
                total = 0
                running_theta_test = [0] * 9
                with torch.no_grad():
                    for client in client_data['test_clients'].values():
                        batched_x, batched_y, batched_a = batch_data(client, 100)
                        for x, y, a in zip(batched_x, batched_y, batched_a):
                            model_output = model(x.float())
                            model_prediction = model_output.argmax(1)
                            for idx, p in enumerate(protected_attributes):
                                mask = a == p
                                if len(model_output[mask]) == 0:
                                    # A float zero keeps later in-place ``+=`` / ``/=``
                                    # valid. torch.tensor(0) made the slot a long
                                    # tensor and those ops raised.
                                    running_theta_test[idx] += model_output.new_zeros(())
                                    continue
                                running_theta_test[idx] += F.cross_entropy(
                                    model_output[mask], y[mask], reduction='sum')
                            correct += model_prediction.eq(y).sum()
                            total += len(y)
                accuracy = correct.item() * 1.0 / total
                print(accuracy)
                test_accs.append(accuracy)

                for idx in range(len(protected_attributes)):
                    running_theta_test[idx] /= num_test_instances[idx]
                print(running_theta_test)

        # NOTE(review): with global_fairness the terms were already divided by
        # the group counts; this divides by them again. Kept as-is -- confirm.
        for idx in range(len(protected_attributes)):
            grad_theta[idx] = grad_theta[idx] / num_instances[idx]
        theta += lr_theta * grad_theta

    return avg_model, test_accs

def Fair_FedAvg_DP_difference(epochs, client_data, model, protected_attributes, initial_lr, local_iterations, rounds, B, lr_theta, threshold, global_fairness=True):
    """FedAvg with a penalty on the gap between two groups' losses.

    ``theta`` is a scalar. Each of ``rounds`` outer rounds sets
    ``lmbda = B * exp(theta) / (1 + exp(theta))`` and runs ``epochs`` FedAvg
    epochs. A client's per-batch loss is the overall cross-entropy plus
    ``lmbda * |L_0 - L_1|``. ``L_i`` is the summed cross-entropy of group
    ``protected_attributes[i]`` on the batch, divided by that group's global
    train count, or 0 if the group is absent from the batch.

    After the last epoch of a round, the group losses are re-evaluated on the
    training data and divided by the group counts again. ``theta`` then
    increases by ``lr_theta * |G_0 - G_1|``.

    NOTE(review): ``threshold`` and ``global_fairness`` are accepted but not
    used. Because the update adds ``|gap|`` (always >= 0), ``theta`` and hence
    ``lmbda`` can only grow across rounds. If the constraint
    ``|gap| <= threshold`` was intended, the step would presumably be
    ``|gap| - threshold``. Behaviour is kept as-is; confirm before changing.

    Returns:
        ``(avg_model, test_accs)``.
        - ``avg_model``: a copy of ``model`` whose parameters are the mean of
          the global parameters at the end of each round. Buffers keep their
          initial values.
        - ``test_accs``: the test accuracy at the end of each round.
    """
    device = next(model.parameters()).device
    theta = torch.zeros((), device=device)
    average_iterate = 0
    iterates = 0
    avg_model = deepcopy(model)
    test_accs = []

    def count_by_group(clients):
        # Per-group example counts summed over clients.
        counts = 0
        for client in clients:
            counts += np.array([(client['a'] == p).sum() for p in protected_attributes])
        return counts

    num_instances = count_by_group(client_data['train_clients'].values())
    num_test_instances = count_by_group(client_data['test_clients'].values())
    print(num_instances, num_test_instances)

    for k in range(rounds):
        lmbda = B * theta.exp() / (1 + theta.exp().sum())
        print(lmbda)
        grad_theta = torch.zeros(len(protected_attributes), device=device)
        lr = initial_lr
        for i in range(epochs):
            sys.stdout.flush()
            all_client_updates = 0
            for client in client_data['train_clients'].values():
                client_model_copy = deepcopy(model)
                client_model_copy.train()
                optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)
                batched_x, batched_y, batched_a = batch_data(client, 50)
                num_batches = len(batched_a)
                for _ in range(local_iterations):
                    for x, y, a in zip(batched_x, batched_y, batched_a):
                        optimizer.zero_grad()
                        model_output = client_model_copy(x.float())
                        subloss = []
                        for p, ins in zip(protected_attributes, num_instances):
                            mask = a == p
                            if len(model_output[mask]) * ins * num_batches == 0:
                                subloss.append(model_output.new_zeros(()))
                            else:
                                subloss.append(F.cross_entropy(
                                    model_output[mask], y[mask], reduction='sum') / ins)
                        loss = lmbda * torch.abs(subloss[0] - subloss[1])
                        loss = loss + F.cross_entropy(model_output, y)
                        loss.backward()
                        optimizer.step()
                all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())
            all_client_updates /= len(client_data['train_clients'])
            w = torch_to_numpy(model.parameters())
            w += all_client_updates
            numpy_to_torch(w, model)
            print('Finishing epoch ', i)

            if i == epochs - 1:
                # Re-evaluate each group's normalised loss on the training data.
                with torch.no_grad():
                    for client in client_data['train_clients'].values():
                        batched_x, batched_y, batched_a = batch_data(client, 50)
                        num_batches = len(batched_a)
                        client_theta = torch.zeros(len(protected_attributes), device=device)
                        for x, y, a in zip(batched_x, batched_y, batched_a):
                            model_output = model(x.float())
                            for p_num, (p, ins) in enumerate(zip(protected_attributes, num_instances)):
                                mask = a == p
                                if len(model_output[mask]) * ins * num_batches == 0:
                                    continue
                                client_theta[p_num] += F.cross_entropy(
                                    model_output[mask], y[mask], reduction='sum') / ins
                        # Add each client's totals once. This used to run inside
                        # the batch loop and re-add the running prefix sum after
                        # every batch, over-counting early batches.
                        grad_theta += client_theta

                iterates += 1
                average_iterate += torch_to_numpy(model.parameters())
                numpy_to_torch(average_iterate / iterates, avg_model)

                print('Printing test accuracy')
                correct = 0
                total = 0
                running_theta_test = [0] * 9
                with torch.no_grad():
                    for client in client_data['test_clients'].values():
                        batched_x, batched_y, batched_a = batch_data(client, 100)
                        for x, y, a in zip(batched_x, batched_y, batched_a):
                            model_output = model(x.float())
                            model_prediction = model_output.argmax(1)
                            for idx, p in enumerate(protected_attributes):
                                mask = a == p
                                if len(model_output[mask]) == 0:
                                    # A float zero keeps later in-place ``+=`` / ``/=``
                                    # valid. torch.tensor(0) made the slot a long
                                    # tensor and those ops raised.
                                    running_theta_test[idx] += model_output.new_zeros(())
                                    continue
                                running_theta_test[idx] += F.cross_entropy(
                                    model_output[mask], y[mask], reduction='sum')
                            correct += model_prediction.eq(y).sum()
                            total += len(y)
                accuracy = correct.item() * 1.0 / total
                print(accuracy)
                test_accs.append(accuracy)

                for idx in range(len(protected_attributes)):
                    running_theta_test[idx] /= num_test_instances[idx]
                print(running_theta_test)

        for idx in range(len(protected_attributes)):
            grad_theta[idx] = grad_theta[idx] / num_instances[idx]
        theta += lr_theta * torch.abs(grad_theta[0] - grad_theta[1])

    return avg_model, test_accs

def Fair_FedAvg_TP_FP(epochs, client_data, model, protected_attributes, initial_lr, local_iterations, rounds, B, lr_theta, pos_threshold, neg_threshold, global_fairness=True, use_tp=True, use_fp=False):
    """Run FedAvg with per-group true-positive / false-positive loss constraints.

    The global ``model`` is trained with FedAvg. Every training client fine-tunes a
    copy with SGD for ``local_iterations`` passes over its batches, and the mean
    parameter delta is applied to ``model``. Each of the ``rounds`` outer rounds
    runs ``epochs`` FedAvg epochs with the multipliers frozen, then takes one
    ascent step on ``theta``.

    For protected value ``p`` and label ``c`` (``c == 1``: TP constraint,
    ``c == 0``: FP constraint), every training batch contributes::

        CE_sum(samples with a == p and y == c) / n_train(p, c) - threshold / n_batches

    Here ``n_batches`` is the client's number of batches. The client loss is the
    mean cross-entropy plus ``sum(lmbda * contribution)``, with
    ``lmbda = B * exp(theta) / (1 + sum(exp(theta)))``. At the last epoch of a
    round the contributions are summed over all training clients and batches,
    and ``theta += lr_theta * sum``.

    Args:
        epochs: FedAvg epochs per round.
        client_data: dict with ``'train_clients'`` / ``'test_clients'`` dicts; each
            client has ``'a'`` (protected attribute) and ``'y'`` (labels) and is
            consumed by ``batch_data``.
        model: global model, updated in place.
        protected_attributes: attribute values defining the groups.
        initial_lr: client SGD learning rate; halved at the start of every round
            after the first.
        local_iterations: passes over a client's batches per FedAvg epoch.
        rounds: number of dual (theta) updates.
        B: scale of the multipliers.
        lr_theta: dual ascent step size.
        pos_threshold, neg_threshold: slacks of the TP / FP constraints.
        global_fairness: unused; kept for interface compatibility.
        use_tp, use_fp: enable the TP (label 1) / FP (label 0) constraints. The
            defaults reproduce the previously hard-coded ``TP=True, FP=False``.

    Returns:
        ``(avg_model, test_accs)``: the running average of the global model taken
        at the last epoch of every round, and the test accuracy recorded at the
        end of every round.
    """
    n_groups = len(protected_attributes)
    train_clients = client_data['train_clients']
    test_clients = client_data['test_clients']

    def _count_instances(clients, label):
        # Per-group number of samples with protected value p and label `label`.
        counts = np.zeros(n_groups, dtype=np.int64)
        for client in clients.values():
            for g, p in enumerate(protected_attributes):
                counts[g] += int(((torch.tensor(client['a']) == p) & (np.array(client['y']) == label)).sum())
        return counts

    num_pos_instances = _count_instances(train_clients, 1)
    num_neg_instances = _count_instances(train_clients, 0)
    num_test_pos_instances = _count_instances(test_clients, 1)
    num_test_neg_instances = _count_instances(test_clients, 0)
    print(num_pos_instances, num_neg_instances, num_test_pos_instances, num_test_neg_instances)

    # Active constraints as (label, offset into theta/lmbda, threshold, train counts).
    # TP duals come first; FP duals follow them when both constraint types are on.
    constraints = []
    if use_tp:
        constraints.append((1, 0, pos_threshold, num_pos_instances))
    if use_fp:
        constraints.append((0, n_groups if use_tp else 0, neg_threshold, num_neg_instances))
    theta = torch.tensor(np.zeros(n_groups * max(1, len(constraints)))).cuda()

    cross_entropy_sum = nn.CrossEntropyLoss(reduction='sum')
    cross_entropy_mean = nn.CrossEntropyLoss()
    average_iterate = 0
    iterates = 0
    avg_model = deepcopy(model)
    test_accs = []

    for k in range(rounds):
        print(B)
        lmbda = B * theta.exp() / (1 + theta.exp().sum())
        print(lmbda)
        # Float accumulator (one slot per dual) so adding float losses never hits
        # an integer tensor.
        grad_theta = torch.zeros_like(theta)
        if k >= 1:
            initial_lr /= 2
        lr = initial_lr
        for i in range(epochs):
            sys.stdout.flush()
            all_client_updates = 0
            for client in train_clients.values():
                client_model_copy = deepcopy(model)
                client_model_copy.train()
                optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)
                batched_x, batched_y, batched_a = batch_data(client, 50)
                n_batches = len(batched_a)
                for _ in range(local_iterations):
                    for x, y, a in zip(batched_x, batched_y, batched_a):
                        optimizer.zero_grad()
                        model_output = client_model_copy(x.float())
                        a_t = torch.tensor(a).cuda()
                        loss = 0
                        for label, offset, threshold, counts in constraints:
                            for g, p in enumerate(protected_attributes):
                                mask = (a_t == p) & (y == label)
                                group_output = model_output[mask]
                                if counts[g] == 0 or len(group_output) == 0:
                                    continue
                                subloss = cross_entropy_sum(group_output, y[mask]) / counts[g] - threshold / n_batches
                                loss += lmbda[offset + g] * subloss
                        loss += cross_entropy_mean(model_output, y)
                        loss.backward()
                        optimizer.step()
                all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())
            all_client_updates /= len(train_clients)
            w = torch_to_numpy(model.parameters())
            w += all_client_updates
            numpy_to_torch(w, model)
            print('Finishing epoch ', i)

            if i != epochs - 1:
                continue

            # Measure the constraints on the training clients. Each batch's
            # contribution is added to grad_theta exactly once.
            # NOTE(review): model.eval() is not called here (it was commented out
            # originally). In train mode BatchNorm uses batch statistics and updates
            # the global model's running stats; presumably intentional, because
            # FedAvg only averages parameters, not buffers. Confirm.
            with torch.no_grad():
                for client in train_clients.values():
                    batched_x, batched_y, batched_a = batch_data(client, 50)
                    n_batches = len(batched_a)
                    for x, y, a in zip(batched_x, batched_y, batched_a):
                        model_output = model(x.float())
                        a_t = torch.tensor(a).cuda()
                        for label, offset, threshold, counts in constraints:
                            for g, p in enumerate(protected_attributes):
                                mask = (a_t == p) & (y == label)
                                group_output = model_output[mask]
                                if counts[g] == 0 or len(group_output) == 0:
                                    continue
                                grad_theta[offset + g] += cross_entropy_sum(group_output, y[mask]) / counts[g] - threshold / n_batches

            iterates += 1
            average_iterate += torch_to_numpy(model.parameters())
            numpy_to_torch(average_iterate / iterates, avg_model)

            print('Printing test accuracy')
            correct = 0
            total = 0
            # Summed test CE per group: TP (y == 1) losses in [0, n_groups),
            # FP (y == 0) losses in [n_groups, 2 * n_groups).
            running_theta_test = [0.0] * (2 * n_groups)
            for client in test_clients.values():
                batched_x, batched_y, batched_a = batch_data(client, 100)
                with torch.no_grad():
                    for x, y, a in zip(batched_x, batched_y, batched_a):
                        model_output = model(x.float())
                        model_prediction = model_output.argmax(1)
                        a_t = torch.tensor(a).cuda()
                        for offset, label in ((0, 1), (n_groups, 0)):
                            for g, p in enumerate(protected_attributes):
                                mask = (a_t == p) & (y == label)
                                group_output = model_output[mask]
                                if len(group_output) == 0:
                                    continue
                                running_theta_test[offset + g] += cross_entropy_sum(group_output, y[mask]).item()
                        correct += int(model_prediction.eq(y).sum())
                        total += len(y)
            accuracy = correct / total if total else float('nan')
            print(accuracy)
            test_accs.append(accuracy)

            # Turn sums into per-group means; groups absent from the test set give NaN.
            for offset, counts in ((0, num_test_pos_instances), (n_groups, num_test_neg_instances)):
                for g in range(n_groups):
                    count = counts[g]
                    running_theta_test[offset + g] = float(running_theta_test[offset + g] / count) if count else float('nan')
            print("TP loss: ", running_theta_test[:n_groups])
            print("FP loss: ", running_theta_test[n_groups:])

        theta += lr_theta * grad_theta

    return avg_model, test_accs

def Fair_FedAvg_TP_FP_local(epochs, client_data, model, protected_attributes, initial_lr, local_iterations, rounds, B, lr_theta, pos_threshold, neg_threshold, global_fairness=True, use_tp=True, use_fp=False):
    """Run FedAvg with per-group TP / FP constraints measured as per-batch mean losses.

    This is the same procedure as the sum-normalised TP/FP variant, with one
    difference. Each training batch's constraint contribution for protected value
    ``p`` and label ``c`` (``c == 1``: TP, ``c == 0``: FP) is::

        CE_mean(samples with a == p and y == c) - threshold / n_batches

    That is, the loss is averaged over the group's samples inside the batch
    instead of being divided by the group's global training count. Here
    ``n_batches`` is the client's number of batches.

    The client loss is the mean cross-entropy plus ``sum(lmbda * contribution)``,
    with ``lmbda = B * exp(theta) / (1 + sum(exp(theta)))``. At the last epoch of
    each round the contributions are summed over all training clients and batches,
    and ``theta += lr_theta * sum``.

    Args:
        epochs: FedAvg epochs per round.
        client_data: dict with ``'train_clients'`` / ``'test_clients'`` dicts; each
            client has ``'a'`` (protected attribute) and ``'y'`` (labels) and is
            consumed by ``batch_data``.
        model: global model, updated in place.
        protected_attributes: attribute values defining the groups.
        initial_lr: client SGD learning rate; halved at the start of every round
            after the first.
        local_iterations: passes over a client's batches per FedAvg epoch.
        rounds: number of dual (theta) updates.
        B: scale of the multipliers.
        lr_theta: dual ascent step size.
        pos_threshold, neg_threshold: slacks of the TP / FP constraints.
        global_fairness: unused; kept for interface compatibility.
        use_tp, use_fp: enable the TP (label 1) / FP (label 0) constraints. The
            defaults reproduce the previously hard-coded ``TP=True, FP=False``.

    Returns:
        ``(avg_model, test_accs)``: the running average of the global model taken
        at the last epoch of every round, and the test accuracy recorded at the
        end of every round.
    """
    n_groups = len(protected_attributes)
    train_clients = client_data['train_clients']
    test_clients = client_data['test_clients']

    def _count_instances(clients, label):
        # Per-group number of samples with protected value p and label `label`.
        counts = np.zeros(n_groups, dtype=np.int64)
        for client in clients.values():
            for g, p in enumerate(protected_attributes):
                counts[g] += int(((torch.tensor(client['a']) == p) & (np.array(client['y']) == label)).sum())
        return counts

    num_pos_instances = _count_instances(train_clients, 1)
    num_neg_instances = _count_instances(train_clients, 0)
    num_test_pos_instances = _count_instances(test_clients, 1)
    num_test_neg_instances = _count_instances(test_clients, 0)
    print(num_pos_instances, num_neg_instances, num_test_pos_instances, num_test_neg_instances)

    # Active constraints as (label, offset into theta/lmbda, threshold, train counts).
    # TP duals come first; FP duals follow them when both constraint types are on.
    constraints = []
    if use_tp:
        constraints.append((1, 0, pos_threshold, num_pos_instances))
    if use_fp:
        constraints.append((0, n_groups if use_tp else 0, neg_threshold, num_neg_instances))
    theta = torch.tensor(np.zeros(n_groups * max(1, len(constraints)))).cuda()

    cross_entropy_sum = nn.CrossEntropyLoss(reduction='sum')
    cross_entropy_group_mean = nn.CrossEntropyLoss(reduction='mean')
    cross_entropy_mean = nn.CrossEntropyLoss()
    average_iterate = 0
    iterates = 0
    avg_model = deepcopy(model)
    test_accs = []

    for k in range(rounds):
        print(B)
        lmbda = B * theta.exp() / (1 + theta.exp().sum())
        print(lmbda)
        # Float accumulator (one slot per dual) so adding float losses never hits
        # an integer tensor.
        grad_theta = torch.zeros_like(theta)
        if k >= 1:
            initial_lr /= 2
        lr = initial_lr
        for i in range(epochs):
            sys.stdout.flush()
            all_client_updates = 0
            for client in train_clients.values():
                client_model_copy = deepcopy(model)
                client_model_copy.train()
                optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)
                batched_x, batched_y, batched_a = batch_data(client, 50)
                n_batches = len(batched_a)
                for _ in range(local_iterations):
                    for x, y, a in zip(batched_x, batched_y, batched_a):
                        optimizer.zero_grad()
                        model_output = client_model_copy(x.float())
                        a_t = torch.tensor(a).cuda()
                        loss = 0
                        for label, offset, threshold, counts in constraints:
                            for g, p in enumerate(protected_attributes):
                                mask = (a_t == p) & (y == label)
                                group_output = model_output[mask]
                                if counts[g] == 0 or len(group_output) == 0:
                                    continue
                                subloss = cross_entropy_group_mean(group_output, y[mask]) - threshold / n_batches
                                loss += lmbda[offset + g] * subloss
                        loss += cross_entropy_mean(model_output, y)
                        loss.backward()
                        optimizer.step()
                all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())
            all_client_updates /= len(train_clients)
            w = torch_to_numpy(model.parameters())
            w += all_client_updates
            numpy_to_torch(w, model)
            print('Finishing epoch ', i)

            if i != epochs - 1:
                continue

            # Measure the constraints on the training clients. Each batch's
            # contribution is added to grad_theta exactly once.
            # NOTE(review): model.eval() is not called here (it was commented out
            # originally). In train mode BatchNorm uses batch statistics and updates
            # the global model's running stats; presumably intentional, because
            # FedAvg only averages parameters, not buffers. Confirm.
            with torch.no_grad():
                for client in train_clients.values():
                    batched_x, batched_y, batched_a = batch_data(client, 50)
                    n_batches = len(batched_a)
                    for x, y, a in zip(batched_x, batched_y, batched_a):
                        model_output = model(x.float())
                        a_t = torch.tensor(a).cuda()
                        for label, offset, threshold, counts in constraints:
                            for g, p in enumerate(protected_attributes):
                                mask = (a_t == p) & (y == label)
                                group_output = model_output[mask]
                                if counts[g] == 0 or len(group_output) == 0:
                                    continue
                                grad_theta[offset + g] += cross_entropy_group_mean(group_output, y[mask]) - threshold / n_batches

            iterates += 1
            average_iterate += torch_to_numpy(model.parameters())
            numpy_to_torch(average_iterate / iterates, avg_model)

            print('Printing test accuracy')
            correct = 0
            total = 0
            # Summed test CE per group: TP (y == 1) losses in [0, n_groups),
            # FP (y == 0) losses in [n_groups, 2 * n_groups).
            running_theta_test = [0.0] * (2 * n_groups)
            for client in test_clients.values():
                batched_x, batched_y, batched_a = batch_data(client, 100)
                with torch.no_grad():
                    for x, y, a in zip(batched_x, batched_y, batched_a):
                        model_output = model(x.float())
                        model_prediction = model_output.argmax(1)
                        a_t = torch.tensor(a).cuda()
                        for offset, label in ((0, 1), (n_groups, 0)):
                            for g, p in enumerate(protected_attributes):
                                mask = (a_t == p) & (y == label)
                                group_output = model_output[mask]
                                if len(group_output) == 0:
                                    continue
                                running_theta_test[offset + g] += cross_entropy_sum(group_output, y[mask]).item()
                        correct += int(model_prediction.eq(y).sum())
                        total += len(y)
            accuracy = correct / total if total else float('nan')
            print(accuracy)
            test_accs.append(accuracy)

            # Turn sums into per-group means; groups absent from the test set give NaN.
            for offset, counts in ((0, num_test_pos_instances), (n_groups, num_test_neg_instances)):
                for g in range(n_groups):
                    count = counts[g]
                    running_theta_test[offset + g] = float(running_theta_test[offset + g] / count) if count else float('nan')
            print("TP loss: ", running_theta_test[:n_groups])
            print("FP loss: ", running_theta_test[n_groups:])

        theta += lr_theta * grad_theta

    return avg_model, test_accs


def FedMinMax(epochs, client_data, model, protected_attributes, initial_lr, local_iterations, rounds, B, lr_theta, threshold):
    """Train with FedAvg under group weights ``theta`` updated by ascent on group risks.

    The algorithm runs per epoch as follows:

    1. Every training client fine-tunes a copy of ``model`` with SGD
       (``lr = initial_lr``) for ``local_iterations`` passes. It minimises
       ``sum_g theta[g] * CE_sum(samples of group g) / n_train(g)``. The mean
       parameter delta is then applied to ``model``.
    2. The same per-group risk is accumulated over all training clients and
       batches with the updated model.
    3. ``theta += lr_theta * grad`` is applied, then ``theta`` is renormalised to
       sum to one.

    Args:
        epochs: number of FedAvg epochs (one theta update each).
        client_data: dict with ``'train_clients'`` / ``'test_clients'`` dicts; each
            client has ``'a'`` (protected attribute, compared elementwise with ``==``)
            and is consumed by ``batch_data``.
        model: global model, updated in place.
        protected_attributes: attribute values defining the groups.
        initial_lr: client SGD learning rate.
        local_iterations: passes over a client's batches per epoch.
        rounds, B, threshold: unused; kept for interface compatibility.
        lr_theta: step size of the theta update.

    Returns:
        ``(avg_model, test_accs, max_sps)``: the running average of the global model
        over all epochs; the test accuracy recorded at the final epoch; and a list
        that is always empty, kept for callers that unpack three values.
    """
    n_groups = len(protected_attributes)
    theta = torch.tensor(np.ones(n_groups)).cuda() * 1.0 / n_groups
    average_iterate = 0
    iterates = 0
    avg_model = deepcopy(model)
    test_accs = []
    max_sps = []
    cross_entropy_sum = nn.CrossEntropyLoss(reduction='sum')
    train_clients = client_data['train_clients']
    test_clients = client_data['test_clients']

    def _count_instances(clients):
        # Per-group number of samples whose protected attribute equals p.
        counts = np.zeros(n_groups, dtype=np.int64)
        for client in clients.values():
            for g, p in enumerate(protected_attributes):
                counts[g] += int((client['a'] == p).sum())
        return counts

    num_instances = _count_instances(train_clients)
    num_test_instances = _count_instances(test_clients)
    print(num_instances, num_test_instances)

    for i in range(epochs):
        sys.stdout.flush()
        lr = initial_lr
        lmbda = theta
        all_client_updates = 0
        for client in train_clients.values():
            client_model_copy = deepcopy(model)
            optimizer = torch.optim.SGD(client_model_copy.parameters(), lr=lr)
            batched_x, batched_y, batched_a = batch_data(client, 50)
            for _ in range(local_iterations):
                for x, y, a in zip(batched_x, batched_y, batched_a):
                    optimizer.zero_grad()
                    model_output = client_model_copy(x.float())
                    loss = 0
                    for g, p in enumerate(protected_attributes):
                        mask = a == p
                        group_output = model_output[mask]
                        if num_instances[g] == 0 or len(group_output) == 0:
                            continue
                        loss += lmbda[g] * (cross_entropy_sum(group_output, y[mask]) / num_instances[g])
                    if not torch.is_tensor(loss):
                        # No sample of any protected group in this batch: nothing to backpropagate.
                        continue
                    loss.backward()
                    optimizer.step()
            all_client_updates += torch_to_numpy(client_model_copy.parameters()) - torch_to_numpy(model.parameters())
        all_client_updates /= len(train_clients)
        w = torch_to_numpy(model.parameters())
        w += all_client_updates
        numpy_to_torch(w, model)

        # Per-group training risk of the updated model. Each batch is added to the
        # float accumulator exactly once.
        # NOTE(review): model.eval() is not called (it was commented out
        # originally); BatchNorm runs in train mode here. Confirm intent.
        grad_theta = torch.zeros_like(theta)
        with torch.no_grad():
            for client in train_clients.values():
                batched_x, batched_y, batched_a = batch_data(client, 50)
                for x, y, a in zip(batched_x, batched_y, batched_a):
                    model_output = model(x.float())
                    for g, p in enumerate(protected_attributes):
                        mask = a == p
                        group_output = model_output[mask]
                        if num_instances[g] == 0 or len(group_output) == 0:
                            continue
                        grad_theta[g] += cross_entropy_sum(group_output, y[mask]) / num_instances[g]

        iterates += 1
        average_iterate += torch_to_numpy(model.parameters())
        numpy_to_torch(average_iterate / iterates, avg_model)

        if i == epochs - 1:
            print('Printing test accuracy')
            correct = 0
            total = 0
            running_theta_test = [0.0] * n_groups
            for client in test_clients.values():
                batched_x, batched_y, batched_a = batch_data(client, 100)
                with torch.no_grad():
                    for x, y, a in zip(batched_x, batched_y, batched_a):
                        model_output = model(x.float())
                        model_prediction = model_output.argmax(1)
                        for g, p in enumerate(protected_attributes):
                            mask = a == p
                            group_output = model_output[mask]
                            if len(group_output) == 0:
                                continue
                            running_theta_test[g] += cross_entropy_sum(group_output, y[mask]).item()
                        correct += int(model_prediction.eq(y).sum())
                        total += len(y)
            accuracy = correct / total if total else float('nan')
            print(accuracy)
            test_accs.append(accuracy)

            # Mean test CE per group; groups absent from the test set give NaN.
            running_theta_test = [float(s / c) if c else float('nan') for s, c in zip(running_theta_test, num_test_instances)]
            print(running_theta_test)

        # NOTE(review): each term was already divided by num_instances[g] above,
        # so this second division scales group g's risk by 1 / n_train(g) again.
        # Kept to preserve the existing update; confirm whether it is intended.
        # Groups with no training samples are skipped to avoid 0/0 -> NaN in theta.
        for g in range(n_groups):
            if num_instances[g] > 0:
                grad_theta[g] = grad_theta[g] / num_instances[g]
        theta += lr_theta * grad_theta
        theta /= theta.sum()

    return avg_model, test_accs, max_sps

if __name__ == '__main__':
    # Command-line hyper-parameters; defaults match the original script.
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--global_lr', type=float, default=1)  # parsed but not passed to any routine below
    parser.add_argument('--threshold', type=float, default=0)
    parser.add_argument('--pos_threshold', type=float, default=0)
    parser.add_argument('--neg_threshold', type=float, default=0)
    parser.add_argument('--lr_theta', type=float, default=0.1)
    parser.add_argument('--local_iterations', type=int, default=2)
    parser.add_argument('--rounds', type=int, default=10)
    parser.add_argument('--B', type=float, default=0.1)
    args = parser.parse_args()

    epochs = args.epochs
    lr = args.lr
    threshold = args.threshold
    global_lr = args.global_lr
    lr_theta = args.lr_theta
    local_iterations = args.local_iterations
    rounds = args.rounds
    B = args.B
    pos_threshold = args.pos_threshold
    neg_threshold = args.neg_threshold

    model = CNN().cuda()
    print('Model loaded')

    all_clients = setup_clients()
    print('All data loaded')

    # NOTE(review): the routines in this file return either (avg_model, test_accs)
    # or (avg_model, test_accs, max_sps), and Fair_FedAvg_DP_difference appears to
    # return only two values. Indexing the result instead of unpacking three
    # names keeps this call working for either shape.
    results = Fair_FedAvg_DP_difference(epochs, all_clients, model, [0, 1], lr, local_iterations, rounds, B, lr_theta, threshold)
    test_accs = results[1]

    # Alternative experiments (swap in one of these instead of the call above):
    # results = FedAvg(epochs, all_clients, model, lr, local_iterations)
    # results = Fair_FedAvg(epochs, all_clients, model, [0, 1], lr, local_iterations, rounds, B, lr_theta, threshold)
    # results = Fair_FedAvg_TP_FP(epochs, all_clients, model, [0, 1], lr, local_iterations, rounds, B, lr_theta, pos_threshold, neg_threshold)
    # results = Fair_FedAvg_TP_FP_local(epochs, all_clients, model, [0, 1], lr, local_iterations, rounds, B, lr_theta, pos_threshold, neg_threshold)
    # results = FedMinMax(epochs, all_clients, model, [0, 1], lr, local_iterations, rounds, B, lr_theta, threshold)




