# coding=utf-8
"""Fairness with noisy protected groups experiments."""
import argparse
import os
import yaml
from pathlib import Path

import numpy as np
from sklearn.cluster import KMeans
import torch
import torch.nn as nn
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter

from utils import generate_proxy_groups_uniform, generate_proxy_groups_noise_array,\
                  generate_proxy_groups_single_noise, error_rate, group_error_rates, \
                  tpr, group_tprs,increment_dir, dataset_iid, average_weights, update_ykt_batch, update_x_k, update_x_k_fedavg
from data import load_data_adult
from model import Linear
from update import adjust_curlr, LocalUpdateDDRO, FedAvg, DatasetSplit

def violation(
    labels, predictions, epsilon, groups):
  """Compute the TPR constraint violation for every group column.

  For each column of `groups`, the violation is
  `overall_tpr - group_tpr - epsilon`, where the two rates come from the
  `tpr` and `group_tprs` helpers and are moved to the CPU first.

  Returns:
    A pair `(worst, per_group)`: `per_group` holds one violation per column
    of `groups`, in column order, and `worst` is `np.max` over that list.
  """
  baseline = tpr(labels, predictions).cpu()
  num_groups = groups.shape[1]
  per_group = [
      baseline - group_tprs(labels, predictions, groups[:, col]).cpu() - epsilon
      for col in range(num_groups)
  ]
  return np.max(per_group), per_group

def lagrangian_loss(labels, predictions, epsilon, 
                    groups, criterion, device, gamma=0.1, alpha=0.1):
  """Return the Lagrangian-style penalty of the constrained problem.

  For each column of `groups`, `criterion` is evaluated on the rows that
  belong to that group and have a label >= 1. The gap between that group
  loss and the overall loss, minus `epsilon`, is passed through
  `exp(gap * alpha / gamma)`, and the terms are accumulated with the
  weights `1/(m+1)` and `1/m`, where `m` is the number of group columns.

  Returns:
    A tensor of shape (1,) created on `device`.
  """
  num_groups = groups.shape[1]
  overall = criterion(predictions, labels)
  # The positive-label mask does not depend on the group, so build it once.
  positives = (labels >= 1.0).squeeze()
  accumulated = torch.zeros(1).to(device)
  for col in range(num_groups):
    members = groups[:, col].logical_and(positives)
    per_group = criterion(predictions[members], labels[members])
    slack = per_group - overall - epsilon
    accumulated += 1/(num_groups+1) * (1/num_groups + torch.exp(slack * alpha / gamma))
  return accumulated

def train_DDRO(args, device=None):
  """Train the linear model with the federated DDRO procedure.

  Every round each client (a) refreshes its tracking value `y_kt` with
  `update_ykt_batch` on one mini-batch, using its previous and current
  local weights, and then (b) takes a local step with `update_x_k` using
  the client-averaged `y_t`. Every `args.I` rounds the local weights are
  averaged into the global model, and every 1000 rounds the global model is
  evaluated on the full train and test sets and logged to TensorBoard.

  Args:
    args: Parsed command-line namespace. `args.rounds` and `args.local_bs`
      are overwritten here.
    device: torch device to run on. When None, CUDA is used if it is
      available and the CPU otherwise.
  """
  tb_writer = SummaryWriter(log_dir=args.log_dir)
  if device is None:
    # Fall back to the CPU instead of failing on hosts without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  ## fix random seeds
  torch.manual_seed(args.random_seed)
  np.random.seed(args.random_seed)
  train_data, test_data = load_data_adult(args.noise_level,
                                          uniform_groups=args.uniform_groups,
                                          min_group_frac=args.min_group_frac,
                                          use_noise_array=args.use_noise_array,
                                          group_features_type=args.group_features_type,
                                          num_group_clusters=args.num_group_clusters)

  # The datasets themselves are not moved to `device` here; device copies are
  # made only for the periodic evaluation below.
  args.rounds = len(train_data)//args.batch_size * args.epochs * args.I
  print(args.rounds)
  args.local_bs = args.batch_size//args.num_users

  ## split the train dataset between the clients (users) of the FL setting
  idxs_users = list(range(args.num_users))
  dict_users = dataset_iid(dataset=train_data, num_users=args.num_users)
  dict_train_loaders = [torch.utils.data.DataLoader(DatasetSplit(train_data, dict_users[idx]),
                                                    batch_size=args.local_bs, shuffle=True,
                                                    num_workers=0, pin_memory=False)
                        for idx in idxs_users]
  y_kt = [0.0 for _ in idxs_users]

  model = Linear(train_data.num_features).to(device)
  weights = model.state_dict()
  local_weights_pre = [deepcopy(weights) for _ in idxs_users]
  local_weights_cur = [deepcopy(weights) for _ in idxs_users]

  # FL training loop over rounds
  for rnd in range(args.rounds):
    epoch = (rnd * args.batch_size) // len(train_data)
    model.train()

    # Step 1: refresh y_kt of every client from its previous/current iterates.
    for idx in idxs_users:
      model_pre = deepcopy(model)
      model_pre.load_state_dict(local_weights_pre[idx])
      model_cur = deepcopy(model)
      model_cur.load_state_dict(local_weights_cur[idx])

      # The resampling period is counted in rounds (`rnd`), despite the
      # option being named epochs_per_resample.
      if args.resample_proxy_groups and rnd % args.epochs_per_resample == 0:
        if args.uniform_groups:
          train_data.proxy_groups_tensor = generate_proxy_groups_uniform(
              len(train_data), min_group_frac=args.min_group_frac).long().to(device)
        elif args.use_noise_array:
          train_data.noise_array = train_data.get_noise_array()
          train_data.proxy_groups_tensor = generate_proxy_groups_noise_array(
              train_data.proxy_groups_tensor, noise_array=train_data.noise_array).long().to(device)
        else:
          train_data.proxy_groups_tensor = generate_proxy_groups_single_noise(
              train_data.proxy_groups_tensor, noise_param=args.noise_level).long().to(device)

      # A fresh iterator over a shuffling loader yields one random batch.
      batch = next(iter(dict_train_loaders[idx]))
      y_kt[idx] = update_ykt_batch(model_pre=model_pre, model_cur=model_cur,
                                   global_round=epoch,
                                   ykt=y_kt[idx],
                                   batch=batch,
                                   beta=args.curbeta,
                                   lmbda=args.lamda)
    y_t = sum(y_kt)/len(y_kt)

    # Snapshot the current per-client weights before they are replaced below.
    # A new list is required: plain assignment would alias both names to one
    # list, so the `local_weights_cur[idx] = ...` writes would also overwrite
    # the "previous" weights and model_pre would always equal model_cur.
    local_weights_pre = list(local_weights_cur)

    # Step 2: local update of every client using the averaged y_t.
    for idx in idxs_users:
      model_cur = deepcopy(model)
      model_cur.load_state_dict(local_weights_cur[idx])

      batch = next(iter(dict_train_loaders[idx]))

      x_k_new = update_x_k(model_cur=model_cur,
                           y_t=y_t,
                           batch=batch,
                           global_round=rnd,
                           lmbda=args.lamda,
                           eta=args.curlr,
                           args=args,
                           verbose=True)

      # record the updated local weights
      local_weights_cur[idx] = deepcopy(x_k_new)

    # Step 3 (server): average the local weights every args.I rounds.
    # NOTE(review): the average is loaded into the global model only;
    # local_weights_cur is not reset to it — confirm this is intended.
    if (rnd + 1) % args.I == 0:
      weights = average_weights(local_weights_cur)
      model.load_state_dict(weights)

    # Evaluate the global model once every 1000 rounds.
    if rnd % 1000 == 0:
      train_data1 = deepcopy(train_data)
      test_data1 = deepcopy(test_data)
      train_data1.to(device)
      test_data1.to(device)
      model.eval()
      with torch.no_grad():
        y_pred_t = model(train_data1.data)
        err = error_rate(train_data1.targets, y_pred_t)
        max_viol, viol_list = violation(
            train_data1.targets, y_pred_t, args.epsilon, train_data1.true_groups_tensor.T)

        y_pred_test = model(test_data1.data)
        err_test = error_rate(test_data1.targets, y_pred_test)
        # Use the device copy for the groups too, matching the train branch.
        max_viol_test, viol_list_test = violation(
            test_data1.targets, y_pred_test, args.epsilon, test_data1.true_groups_tensor.T)

        tb_writer.add_scalar('Train/Error', err, rnd)
        tb_writer.add_scalar('Test/Error',err_test, rnd)
        tb_writer.add_scalar('Train/Max_violation',max_viol, rnd)
        tb_writer.add_scalar('Test/Max_violation',max_viol_test, rnd)

        # The "Error" printed here is the test error.
        print("Round %d/ %d | Epoch %d | Error = %.3f | Viol = %.3f | Viol_test = %.3f" %
              (rnd, args.rounds, epoch, err_test, max_viol, max_viol_test), flush=True)

  # Flush pending TensorBoard events to disk.
  tb_writer.close()
  return

def train_fedavg(args, device=None):
  """Train the linear model with the FedAvg-style baseline.

  Every round each client takes one local step with `update_x_k_fedavg` on a
  random mini-batch drawn from the whole training set (the data is not split
  between clients here). Every `args.I` rounds the local weights are
  averaged into the global model, and every 1000 rounds the global model is
  evaluated on the full train and test sets and logged to TensorBoard.

  Args:
    args: Parsed command-line namespace. `args.rounds` and `args.local_bs`
      are overwritten here.
    device: torch device to run on. When None, CUDA is used if it is
      available and the CPU otherwise.
  """
  tb_writer = SummaryWriter(log_dir=args.log_dir)
  if device is None:
    # Fall back to the CPU instead of failing on hosts without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  ## fix random seeds
  torch.manual_seed(args.random_seed)
  np.random.seed(args.random_seed)
  train_data, test_data = load_data_adult(args.noise_level,
                                          uniform_groups=args.uniform_groups,
                                          min_group_frac=args.min_group_frac,
                                          use_noise_array=args.use_noise_array,
                                          group_features_type=args.group_features_type,
                                          num_group_clusters=args.num_group_clusters)

  args.rounds = len(train_data)//args.batch_size * args.epochs * args.I
  args.local_bs = args.batch_size//args.num_users

  idxs_users = list(range(args.num_users))
  # The per-client split is computed but not used below: every client loader
  # draws from the whole training set. The call is kept so that whatever it
  # does internally (presumably random sampling — TODO confirm) still happens.
  dataset_iid(dataset=train_data, num_users=args.num_users)
  dict_train_loaders = [torch.utils.data.DataLoader(train_data,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=0,
                                                    pin_memory=False)
                        for _ in idxs_users]

  model = Linear(train_data.num_features).to(device)
  weights = model.state_dict()
  local_weights_cur = [deepcopy(weights) for _ in idxs_users]

  # FL training loop over rounds
  for rnd in range(args.rounds):
    epoch = (rnd * args.batch_size) // len(train_data)

    # The resampling period is counted in rounds (`rnd`), despite the option
    # being named epochs_per_resample.
    if args.resample_proxy_groups and rnd % args.epochs_per_resample == 0:
      if args.uniform_groups:
        train_data.proxy_groups_tensor = generate_proxy_groups_uniform(
            len(train_data), min_group_frac=args.min_group_frac).long().to(device)
      elif args.use_noise_array:
        train_data.noise_array = train_data.get_noise_array()
        train_data.proxy_groups_tensor = generate_proxy_groups_noise_array(
            train_data.proxy_groups_tensor, noise_array=train_data.noise_array).long().to(device)
      else:
        train_data.proxy_groups_tensor = generate_proxy_groups_single_noise(
            train_data.proxy_groups_tensor, noise_param=args.noise_level).long().to(device)

    model.train()

    # Local update of every client.
    for idx in idxs_users:
      model_cur = deepcopy(model)
      model_cur.load_state_dict(local_weights_cur[idx])

      # A fresh iterator over a shuffling loader yields one random batch.
      batch = next(iter(dict_train_loaders[idx]))

      x_k_new = update_x_k_fedavg(model_cur=model_cur,
                                  batch=batch,
                                  global_round=rnd,
                                  lmbda=args.lamda,
                                  eta=args.curlr,
                                  args=args,
                                  verbose=True)

      # record the updated local weights
      local_weights_cur[idx] = deepcopy(x_k_new)

    # Server step: average the local weights every args.I rounds.
    # NOTE(review): the average is loaded into the global model only;
    # local_weights_cur is not reset to it — confirm this is intended.
    if (rnd + 1) % args.I == 0:
      weights = average_weights(local_weights_cur)
      model.load_state_dict(weights)

    # Evaluate the global model once every 1000 rounds.
    if rnd % 1000 == 0:
      train_data1 = deepcopy(train_data)
      test_data1 = deepcopy(test_data)
      train_data1.to(device)
      test_data1.to(device)
      model.eval()
      with torch.no_grad():
        y_pred_t = model(train_data1.data)
        err = error_rate(train_data1.targets, y_pred_t)
        max_viol, viol_list = violation(
            train_data1.targets, y_pred_t, args.epsilon, train_data1.true_groups_tensor.T)

        y_pred_test = model(test_data1.data)
        err_test = error_rate(test_data1.targets, y_pred_test)
        # Use the device copy for the groups too, matching the train branch.
        max_viol_test, viol_list_test = violation(
            test_data1.targets, y_pred_test, args.epsilon, test_data1.true_groups_tensor.T)

        tb_writer.add_scalar('Train/Error', err, rnd)
        tb_writer.add_scalar('Test/Error',err_test, rnd)
        tb_writer.add_scalar('Train/Max_violation',max_viol, rnd)
        tb_writer.add_scalar('Test/Max_violation',max_viol_test, rnd)

        # The "Error" printed here is the test error.
        print("Round %d/ %d | Epoch %d | Error = %.3f | Viol = %.3f | Viol_test = %.3f" %
              (rnd, args.rounds, epoch, err_test, max_viol, max_viol_test), flush=True)

  # Flush pending TensorBoard events to disk.
  tb_writer.close()
  return


def train_unconstrained(args, device=None):
  """Train the linear model federatedly with the plain BCE loss (no constraint).

  Every round each client takes one Adagrad step on a random mini-batch of
  its own data split. Every `args.I` rounds the local weights are averaged
  into the global model, and every 1000 rounds the global model is evaluated
  on the full train and test sets and logged to TensorBoard.

  Args:
    args: Parsed command-line namespace. `args.rounds` and `args.local_bs`
      are overwritten here.
    device: torch device to run on. When None, CUDA is used if it is
      available and the CPU otherwise.
  """
  tb_writer = SummaryWriter(log_dir=args.log_dir)
  if device is None:
    # Fall back to the CPU instead of failing on hosts without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  ## fix random seeds
  torch.manual_seed(args.random_seed)
  np.random.seed(args.random_seed)
  criterion = nn.BCEWithLogitsLoss()
  train_data, test_data = load_data_adult(args.noise_level,
                                          uniform_groups=args.uniform_groups,
                                          min_group_frac=args.min_group_frac,
                                          use_noise_array=args.use_noise_array,
                                          group_features_type=args.group_features_type,
                                          num_group_clusters=args.num_group_clusters)

  train_data.to(device)
  test_data.to(device)

  args.rounds = len(train_data)//args.batch_size * args.epochs * args.I
  args.local_bs = args.batch_size//args.num_users
  print(args.rounds)

  ## split the train dataset between the clients (users) of the FL setting
  idxs_users = list(range(args.num_users))
  dict_users = dataset_iid(dataset=train_data, num_users=args.num_users)
  dict_train_loaders = [torch.utils.data.DataLoader(DatasetSplit(train_data, dict_users[idx]),
                                                    batch_size=args.local_bs, shuffle=True,
                                                    num_workers=0, pin_memory=False)
                        for idx in idxs_users]

  model = Linear(train_data.num_features).to(device)
  weights = model.state_dict()
  local_weights_cur = [deepcopy(weights) for _ in idxs_users]

  for rnd in range(args.rounds):
    epoch = (rnd * args.batch_size) // len(train_data)
    model.train()

    # Local update of every client.
    for idx in idxs_users:
      model_cur = deepcopy(model)
      model_cur.load_state_dict(local_weights_cur[idx])
      # NOTE(review): the optimizer is rebuilt on every step, so Adagrad's
      # accumulated state never carries over between steps — confirm intended.
      optim_cur = torch.optim.Adagrad(model_cur.parameters(), lr=args.lr)

      # The resampling period is counted in rounds (`rnd`), despite the
      # option being named epochs_per_resample.
      if args.resample_proxy_groups and rnd % args.epochs_per_resample == 0:
        if args.uniform_groups:
          train_data.proxy_groups_tensor = generate_proxy_groups_uniform(
              len(train_data), min_group_frac=args.min_group_frac).long().to(device)
        elif args.use_noise_array:
          train_data.noise_array = train_data.get_noise_array()
          train_data.proxy_groups_tensor = generate_proxy_groups_noise_array(
              train_data.proxy_groups_tensor, noise_array=train_data.noise_array).long().to(device)
        else:
          train_data.proxy_groups_tensor = generate_proxy_groups_single_noise(
              train_data.proxy_groups_tensor, noise_param=args.noise_level).long().to(device)

      model_cur.train()
      # A fresh iterator over a shuffling loader yields one random batch.
      images, labels = next(iter(dict_train_loaders[idx]))
      images, labels = images.to(device), labels.to(device)

      optim_cur.zero_grad()
      y_pred_idx = model_cur(images)
      main_loss = criterion(y_pred_idx, labels)
      main_loss.backward()
      optim_cur.step()

      # record the updated local weights
      local_weights_cur[idx] = deepcopy(model_cur.state_dict())

    # Server step: average the local weights every args.I rounds.
    # NOTE(review): the average is loaded into the global model only;
    # local_weights_cur is not reset to it — confirm this is intended.
    if (rnd + 1) % args.I == 0:
      weights = average_weights(local_weights_cur)
      model.load_state_dict(weights)

    # Evaluate the global model once every 1000 rounds.
    if rnd % 1000 == 0:
      model.eval()
      with torch.no_grad():
        y_pred_t = model(train_data.data)
        err = error_rate(train_data.targets, y_pred_t)
        max_viol, viol_list = violation(
            train_data.targets, y_pred_t, args.epsilon, train_data.true_groups_tensor.T)

        y_pred_test = model(test_data.data)
        err_test = error_rate(test_data.targets, y_pred_test)
        max_viol_test, viol_list_test = violation(
            test_data.targets, y_pred_test, args.epsilon, test_data.true_groups_tensor.T)

        # main_loss is the mini-batch loss of the last client of this round.
        tb_writer.add_scalar('Train/Loss',main_loss.item(),rnd)
        tb_writer.add_scalar('Train/Error', err,rnd)
        tb_writer.add_scalar('Test/Error',err_test,rnd)
        tb_writer.add_scalar('Train/Max_violation',max_viol,rnd)
        tb_writer.add_scalar('Test/Max_violation',max_viol_test,rnd)

        # The "Error" printed here is the test error.
        print("Round %d/ %d | Epoch %d | Error = %.3f | Viol = %.3f | Viol_test = %.3f" %
              (rnd, args.rounds, epoch, err_test, max_viol, max_viol_test), flush=True)

  # Flush pending TensorBoard events to disk.
  tb_writer.close()
  return
def train_constrained(args, device=None):
  """Train the linear model full-batch with the constrained (DRO) objective.

  Each epoch evaluates the BCE loss and `lagrangian_loss` for the current
  model and for `model_last` (a copy holding the weights from before the
  latest optimizer step), hands both to `model.backward_dro`, and then steps
  the main optimizer. Every 10 epochs the model is evaluated on the full
  train and test sets and logged to TensorBoard.

  Args:
    args: Parsed command-line namespace.
    device: torch device to run on. Defaults to the CPU when None.
  """
  tb_writer = SummaryWriter(log_dir=args.log_dir)
  if device is None:
    device = torch.device('cpu')

  # Fix random seeds like the other training entry points do; skipped when
  # the namespace carries no seed so existing callers keep working.
  seed = getattr(args, 'random_seed', None)
  if seed is not None:
    torch.manual_seed(seed)
    np.random.seed(seed)

  criterion = nn.BCEWithLogitsLoss()
  train_data, test_data = load_data_adult(args.noise_level,
                                          uniform_groups=args.uniform_groups,
                                          min_group_frac=args.min_group_frac,
                                          use_noise_array=args.use_noise_array,
                                          group_features_type=args.group_features_type,
                                          num_group_clusters=args.num_group_clusters)

  train_data.to(device)
  test_data.to(device)

  # The model must live on the same device as the data and as model_last.
  model = Linear(train_data.num_features).to(device)
  model_last = deepcopy(model).to(device)

  optimizers = {'main': torch.optim.Adagrad(model.parameters(), lr=args.lr),
                'last': torch.optim.Adagrad(model_last.parameters(), lr=args.lr)
  }
  # Placeholders; both entries are overwritten before their first use.
  const_losses = {'main': torch.zeros(1).to(device), 'last': torch.zeros(1).to(device)}

  for e in range(args.epochs):
    # Resample the proxy groups every epochs_per_resample epochs.
    if args.resample_proxy_groups and args.constrained:
      if e % args.epochs_per_resample == 0:
        if args.uniform_groups:
          train_data.proxy_groups_tensor = generate_proxy_groups_uniform(
              len(train_data), min_group_frac=args.min_group_frac).long().to(device)
        elif args.use_noise_array:
          train_data.noise_array = train_data.get_noise_array()
          train_data.proxy_groups_tensor = generate_proxy_groups_noise_array(
              train_data.proxy_groups_tensor, noise_array=train_data.noise_array).long().to(device)
        else:
          train_data.proxy_groups_tensor = generate_proxy_groups_single_noise(
              train_data.proxy_groups_tensor, noise_param=args.noise_level).long().to(device)
    model.train()
    model_last.train()
    optimizers['main'].zero_grad()
    y_pred = model(train_data.data)
    main_loss = criterion(y_pred, train_data.targets)
    const_losses['main'] = lagrangian_loss(train_data.targets, y_pred, args.epsilon, train_data.proxy_groups_tensor,
                                  criterion, device, gamma=args.dual_scale) 
    y_pred_last = model_last(train_data.data)
    const_losses['last'] = lagrangian_loss(train_data.targets, y_pred_last, args.epsilon, train_data.proxy_groups_tensor,
                                  criterion, device, gamma=args.dual_scale) 
    model.backward_dro(main_loss, const_losses, optimizers, model_last)
    # Copy the weights before the step so model_last holds the previous iterate.
    model_last.load_state_dict(model.state_dict().copy())
    optimizers['main'].step()

    # Evaluate once every 10 epochs.
    if e % 10 == 0:
      model.eval()
      with torch.no_grad():
        y_pred_t = model(train_data.data)
        err = error_rate(train_data.targets, y_pred_t)
        max_viol, viol_list = violation(
            train_data.targets, y_pred_t, args.epsilon, train_data.true_groups_tensor.T)

        y_pred_test = model(test_data.data)
        err_test = error_rate(test_data.targets, y_pred_test)
        max_viol_test, viol_list_test = violation(
            test_data.targets, y_pred_test, args.epsilon, test_data.true_groups_tensor.T)

        tb_writer.add_scalar('Train/Accuracy', 1-err, e)
        tb_writer.add_scalar('Test/Accuracy',1-err_test, e)
        tb_writer.add_scalar('Train/Max_violation',max_viol, e)
        tb_writer.add_scalar('Test/Max_violation',max_viol_test, e)

        print("Epoch %d | Train accuracy = %.3f | Test accuracy = %.3f | Viol = %.3f | Viol_test = %.3f" %
              (e, 1-err, 1-err_test, max_viol, max_viol_test), flush=True)

  # Flush pending TensorBoard events to disk.
  tb_writer.close()
  return


if __name__=="__main__":
    # Command-line entry point: parse the options, create a fresh log
    # directory, save the options there as YAML, then run the selected
    # training routine (--DDRO, --constrained, --fedavg, else unconstrained).
    parser = argparse.ArgumentParser()
    parser.add_argument('-dro', '--DDRO', action='store_true')
    parser.add_argument('-con', '--constrained', action='store_true')
    parser.add_argument('-fed', '--fedavg', action='store_true')
    parser.add_argument('-e', '--epsilon', default=0.01, type=float) 
    parser.add_argument('-b', '--batch_size', default=100, type=int)
    parser.add_argument('-l', '--lr', default=0.01, type=float)
    parser.add_argument('-i', '--iterations',default=50000, type=int)
    parser.add_argument('-p', '--epochs',default=500, type=int)
    parser.add_argument('-o', '--log',default=100, type=int)
    parser.add_argument('-g', '--dual_scale', default=0.1, type=float)
    parser.add_argument('-d', '--logdir', default='/home/rafi/GCIVR', type=str)
    parser.add_argument('-n', '--noise_level', default=0.3, type=float)
    parser.add_argument('-un', '--use_noise_array', action='store_true')
    parser.add_argument('-ug', '--uniform_groups', action='store_true')
    parser.add_argument('-gf', '--min_group_frac', default=0.01, type=float) 
    parser.add_argument('-ft', '--group_features_type', default='full_group_vec', type=str)
    parser.add_argument('-nc', '--num_group_clusters',default=100, type=int)
    parser.add_argument('-rp', '--resample_proxy_groups', action='store_true')
    parser.add_argument('-er', '--epochs_per_resample', default=1, type=int)
    parser.add_argument('-f', '--full_step',default=100, type=int)
    parser.add_argument('-k', '--num_users',default=10, type=int)
    parser.add_argument('-rs', '--random_seed', default=3, type=int)
    # Help text fixed: this option is the `beta` handed to update_ykt_batch,
    # not a learning rate.
    parser.add_argument('--curbeta', default=0.1, type=float, help='current beta used in the y_k update')
    parser.add_argument('--beta', default=0.6, type = float, help = 'momentum parameters for SCCMA')
    parser.add_argument('--curlr', default=0.01, type=float, help='current learning rate')
    parser.add_argument('--lamda', default=5, type = float, help = 'parameters of regularization')
    parser.add_argument('--local_ep', default=1, type=int, help='local_epoch for update')
    parser.add_argument('--I', default=1, type=int, help='the frequency for FL communication to updata global model')

    args = parser.parse_args()

    args.log_dir =  increment_dir(Path(args.logdir) / 'exp')
    os.makedirs(args.log_dir)
    yaml_file = str(Path(args.log_dir) / "args.yaml")
    with open(yaml_file, 'w') as out:
      yaml.dump(vars(args), out, default_flow_style=False)
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Pass the detected device on, so the trainers that would otherwise
    # default to CUDA also run on CPU-only hosts.
    if args.DDRO:
      train_DDRO(args, device=device)
    elif args.constrained:
      # train_constrained keeps its own default device (CPU).
      train_constrained(args) 
    elif args.fedavg:
      train_fedavg(args, device=device)
    else:
      train_unconstrained(args, device=device)