# coding=utf-8
"""Fairness with noisy protected groups experiments."""
import argparse
import os
import yaml
from pathlib import Path

import numpy as np
from sklearn.cluster import KMeans
import torch
import torch.nn as nn
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter

from utils import generate_proxy_groups_uniform, generate_proxy_groups_noise_array,\
                  generate_proxy_groups_single_noise, error_rate, group_error_rates, \
                  tpr, group_tprs,increment_dir, dataset_iid
from data import load_data_adult
from model import Linear
from update import adjust_curlr, LocalUpdateDDRO, FedAvg, DatasetSplit
from torch.utils.data import DataLoader, Dataset
from collections import OrderedDict

def violation(
    labels, predictions, epsilon, groups):
  """Measure how far each group's true-positive rate lags the overall one.

  For every column ``j`` of ``groups`` the violation is
  ``tpr(all) - tpr(group j) - epsilon``. A positive value means group ``j``
  falls behind the overall TPR by more than the slack ``epsilon``.

  Args:
    labels: targets, forwarded unchanged to ``tpr`` / ``group_tprs``.
    predictions: model outputs, forwarded unchanged to ``tpr`` / ``group_tprs``.
    epsilon: allowed slack subtracted from every group's gap.
    groups: 2-D tensor indexed as ``groups[:, j]``, one column per group
      (presumably 0/1 membership -- confirm against callers).

  Returns:
    ``(worst, per_group)``: ``np.max`` over the per-group violations, and the
    list of per-group violations (moved to CPU) in column order.
  """
  reference = tpr(labels, predictions).cpu()
  per_group = [
      reference - group_tprs(labels, predictions, groups[:, col]).cpu() - epsilon
      for col in range(groups.shape[1])
  ]
  return np.max(per_group), per_group

def lagrangian_loss(labels, predictions, epsilon,
                    groups, criterion, device, gamma=0.1, alpha=0.1):
  """Return the exponential penalty for the per-group loss constraints.

  With ``m = groups.shape[1]``, each group column ``kk`` defines the
  constraint value ``c_kk = loss(positives in group kk) - loss(all) - epsilon``.
  That group contributes ``1/(m+1) * (1/m + exp(c_kk * alpha / gamma))`` to
  the result.

  Groups that contain no positive example (``labels >= 1``) are skipped.
  For such a group the selection is empty, and a mean-reduced criterion
  returns NaN on an empty input, which would otherwise turn the whole
  penalty and its gradient into NaN.

  Args:
    labels: targets; ``(labels >= 1.0).squeeze()`` marks positive examples.
    predictions: model outputs, masked with the same boolean index as labels.
    epsilon: slack subtracted from every constraint.
    groups: 2-D tensor indexed as ``groups[:, kk]``, one column per group.
    criterion: loss callable invoked as ``criterion(predictions, labels)``.
    device: device on which the accumulator is allocated.
    gamma: divides the constraint value inside ``exp``.
    alpha: multiplies the constraint value inside ``exp``.

  Returns:
    A 1-element tensor on ``device`` holding the summed penalty.
  """
  total_loss = torch.zeros(1, device=device)
  main_loss = criterion(predictions, labels)
  # Loop-invariant: which examples are positives.
  positives = (labels >= 1.0).squeeze()
  m = groups.shape[1]
  for kk in range(m):
    g_inds = groups[:, kk].logical_and(positives)
    if not g_inds.any():
      # Empty selection -> NaN loss; this group imposes no usable constraint.
      continue
    group_loss = criterion(predictions[g_inds], labels[g_inds])
    const_loss = group_loss - main_loss - epsilon
    total_loss += 1 / (m + 1) * (1 / m + torch.exp(const_loss * alpha / gamma))
  return total_loss

def train_DSDDRO(args, device=None):
  """Train the Adult linear model with DS-FedDRO, logging to TensorBoard.

  At every step each client first refreshes its dual variable ``y_kt``
  (``LocalUpdateDDRO.update_ykt``) and then trains locally with it. Every
  ``args.I`` steps (one communication round) the server
    * moves the consensus dual ``y_t`` towards the client mean with step
      ``args.gamma_y`` and broadcasts it to all clients, and
    * moves the global weights towards the client average with step
      ``args.gamma_x``.
  After every 10th round, train/test accuracy and the maximum TPR violation
  w.r.t. the true groups are logged to TensorBoard and printed.

  Args:
    args: parsed command-line namespace (see the ``__main__`` block).
    device: torch device to train on; defaults to CUDA when available and
      to CPU otherwise.
  """
  print("DS-FedDRO")
  print("GPU: %d, clients: %d, batch size: %d, lambda: %f, I: %d, local batch: %d, gamma_x: %lf, gamma_y: %lf" %(args.GPU, args.num_users, args.local_bs, args.lmbda, args.I, args.local_bs, args.gamma_x, args.gamma_y))
  tb_writer = SummaryWriter(log_dir=args.log_dir)
  if device is None:
    # Do not hard-require CUDA: fall back to CPU when no GPU is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  # Fix random seeds for reproducibility.
  torch.manual_seed(args.random_seed)
  np.random.seed(args.random_seed)
  dataset_train, dataset_test = load_data_adult(args.noise_level,
                                                uniform_groups=args.uniform_groups,
                                                min_group_frac=args.min_group_frac,
                                                use_noise_array=args.use_noise_array,
                                                group_features_type=args.group_features_type,
                                                num_group_clusters=args.num_group_clusters)

  # Split the whole training set across the clients.
  dict_users = dataset_iid(dataset=dataset_train, num_users=args.num_users)

  net_glob = Linear(dataset_train.num_features).to(device)
  net_glob.train()
  w_glob = net_glob.state_dict()
  # NOTE(review): w_glob_prev is never refreshed after an aggregation, so
  # update_ykt always receives the initial weights here -- confirm intended.
  w_glob_prev = deepcopy(w_glob)

  w_locals = [w_glob for _ in range(args.num_users)]
  y_kt = [0.0 for _ in range(args.num_users)]
  y_t = sum(y_kt) / len(y_kt)
  epochs = max(args.com * args.I, args.epochs)
  com = 0
  try:
    for step in range(epochs):
      if args.resample_proxy_groups and step % args.epochs_per_resample == 0:
        # Redraw the proxy groups once every epochs_per_resample steps.
        if args.uniform_groups:
          dataset_train.proxy_groups_tensor = generate_proxy_groups_uniform(
              len(dataset_train), min_group_frac=args.min_group_frac).long().to(device)
        elif args.use_noise_array:
          dataset_train.noise_array = dataset_train.get_noise_array()
          dataset_train.proxy_groups_tensor = generate_proxy_groups_noise_array(
              dataset_train.proxy_groups_tensor, noise_array=dataset_train.noise_array).long().to(device)
        else:
          dataset_train.proxy_groups_tensor = generate_proxy_groups_single_noise(
              dataset_train.proxy_groups_tensor, noise_param=args.noise_level).long().to(device)

      adjust_curlr(step, args, optimizer=None)

      loss_locals = []
      # A random permutation of all clients: every client participates.
      idxs_users = np.random.choice(range(args.num_users), args.num_users, replace=False)

      # Refresh every client's dual variable y_kt.
      for idx in idxs_users:
        local = LocalUpdateDDRO(args=args, dataset=dataset_train, idxs=dict_users[idx])
        y_kt[idx] = local.update_ykt(deepcopy(net_glob).to(device),
                                     w_glob, w_glob_prev, ykt=y_kt[idx])

      communicate = (step + 1) % args.I == 0
      if communicate:
        # y_t <- y_t - gamma_y * mean_k(y_t - y_kt), then broadcast y_t.
        y_gap = sum(y_t - y_k for y_k in y_kt) / len(y_kt)
        y_t = y_t - args.gamma_y * y_gap
        for idx in range(len(y_kt)):
          y_kt[idx] = y_t

      # Local training, each client using its (possibly just broadcast) y_kt.
      for idx in idxs_users:
        local = LocalUpdateDDRO(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, loss = local.train(net=deepcopy(net_glob), y_t=y_kt[idx])
        w_locals[idx] = deepcopy(w)
        loss_locals.append(deepcopy(loss))

      if not communicate:
        continue
      com += 1
      # w <- w - gamma_x * mean_k(w - w_k), applied parameter by parameter.
      for key in w_glob.keys():
        w_gap = sum(w_glob[key] - w_loc[key] for w_loc in w_locals) / len(w_locals)
        w_glob[key] = w_glob[key] - args.gamma_x * w_gap
      net_glob.load_state_dict(w_glob)

      if com % 10 == 0:
        # Evaluate on copies so the original datasets stay where they are.
        train_data1 = deepcopy(dataset_train)
        test_data1 = deepcopy(dataset_test)
        train_data1.to(device)
        test_data1.to(device)
        net_glob.eval()
        with torch.no_grad():
          y_pred_t = net_glob(train_data1.data)
          err = error_rate(train_data1.targets, y_pred_t)
          max_viol, _ = violation(
              train_data1.targets, y_pred_t, args.epsilon, train_data1.true_groups_tensor.T)

          y_pred_test = net_glob(test_data1.data)
          err_test = error_rate(test_data1.targets, y_pred_test)
          max_viol_test, _ = violation(
              test_data1.targets, y_pred_test, args.epsilon, test_data1.true_groups_tensor.T)

          tb_writer.add_scalar('Train/Accuracy', 1-err, com)
          tb_writer.add_scalar('Test/Accuracy',1-err_test, com)
          tb_writer.add_scalar('Train/Max_violation',max_viol, com)
          tb_writer.add_scalar('Test/Max_violation',max_viol_test, com)

          print("Epoch %d| Comm rounds %d | Train accuracy = %.3f | Test accuracy = %.3f  Viol = %.3f | Viol_test = %.3f" %
                (step, com, 1-err, 1-err_test, max_viol, max_viol_test), flush=True)
        # net_glob is deep-copied into the clients; hand them train mode.
        net_glob.train()
  finally:
    tb_writer.close()
  return

def train_DDRO(args, device=None):
  """Train the Adult linear model with FedDRO, logging to TensorBoard.

  At every step each client refreshes its dual variable ``y_kt``
  (``LocalUpdateDDRO.update_ykt``). The clients then train locally using the
  plain mean ``y_t`` of all ``y_kt``. Every ``args.I`` steps (one
  communication round) the global weights become ``FedAvg`` of the client
  weights. After every 10th round, train/test accuracy and the maximum TPR
  violation w.r.t. the true groups are logged to TensorBoard and printed.

  Args:
    args: parsed command-line namespace (see the ``__main__`` block).
    device: torch device to train on; defaults to CUDA when available and
      to CPU otherwise.
  """
  print("FedDRO")
  print("GPU: %d, clients: %d, batch size: %d, lambda: %f, I: %d, local batch: %d, gamma_x: %lf, gamma_y: %lf" %(args.GPU, args.num_users, args.local_bs, args.lmbda, args.I, args.local_bs, args.gamma_x, args.gamma_y))

  print(args.log_dir)
  tb_writer = SummaryWriter(log_dir=args.log_dir)
  if device is None:
    # Do not hard-require CUDA: fall back to CPU when no GPU is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  # Fix random seeds for reproducibility.
  torch.manual_seed(args.random_seed)
  np.random.seed(args.random_seed)
  dataset_train, dataset_test = load_data_adult(args.noise_level,
                                                uniform_groups=args.uniform_groups,
                                                min_group_frac=args.min_group_frac,
                                                use_noise_array=args.use_noise_array,
                                                group_features_type=args.group_features_type,
                                                num_group_clusters=args.num_group_clusters)

  # Split the whole training set across the clients.
  dict_users = dataset_iid(dataset=dataset_train, num_users=args.num_users)

  net_glob = Linear(dataset_train.num_features).to(device)
  net_glob.train()
  w_glob = net_glob.state_dict()
  # NOTE(review): w_glob_prev is never refreshed after an aggregation, so
  # update_ykt always receives the initial weights here -- confirm intended.
  w_glob_prev = deepcopy(w_glob)

  w_locals = [w_glob for _ in range(args.num_users)]
  y_kt = [0.0 for _ in range(args.num_users)]

  epochs = max(args.com * args.I, args.epochs)
  com = 0
  try:
    for step in range(epochs):
      if args.resample_proxy_groups and step % args.epochs_per_resample == 0:
        # Redraw the proxy groups once every epochs_per_resample steps.
        if args.uniform_groups:
          dataset_train.proxy_groups_tensor = generate_proxy_groups_uniform(
              len(dataset_train), min_group_frac=args.min_group_frac).long().to(device)
        elif args.use_noise_array:
          dataset_train.noise_array = dataset_train.get_noise_array()
          dataset_train.proxy_groups_tensor = generate_proxy_groups_noise_array(
              dataset_train.proxy_groups_tensor, noise_array=dataset_train.noise_array).long().to(device)
        else:
          dataset_train.proxy_groups_tensor = generate_proxy_groups_single_noise(
              dataset_train.proxy_groups_tensor, noise_param=args.noise_level).long().to(device)

      adjust_curlr(step, args, optimizer=None)

      loss_locals = []
      # A random permutation of all clients: every client participates.
      idxs_users = np.random.choice(range(args.num_users), args.num_users, replace=False)

      # Refresh every client's dual variable y_kt.
      for idx in idxs_users:
        local = LocalUpdateDDRO(args=args, dataset=dataset_train, idxs=dict_users[idx])
        y_kt[idx] = local.update_ykt(deepcopy(net_glob).to(device),
                                     w_glob, w_glob_prev, ykt=y_kt[idx])
      # All clients train with the plain average of the duals.
      y_t = sum(y_kt) / len(y_kt)

      for idx in idxs_users:
        local = LocalUpdateDDRO(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, loss = local.train(net=deepcopy(net_glob), y_t=y_t)
        w_locals[idx] = deepcopy(w)
        loss_locals.append(deepcopy(loss))

      if (step + 1) % args.I != 0:
        continue
      com += 1
      w_glob = FedAvg(w_locals)
      net_glob.load_state_dict(w_glob)

      # Evaluate only right after an aggregation. Outside this branch, with
      # I > 1 the same `com` step (including com == 0) was evaluated and
      # logged several times.
      if com % 10 == 0:
        # Evaluate on copies so the original datasets stay where they are.
        train_data1 = deepcopy(dataset_train)
        test_data1 = deepcopy(dataset_test)
        train_data1.to(device)
        test_data1.to(device)
        net_glob.eval()
        with torch.no_grad():
          y_pred_t = net_glob(train_data1.data)
          err = error_rate(train_data1.targets, y_pred_t)
          max_viol, _ = violation(
              train_data1.targets, y_pred_t, args.epsilon, train_data1.true_groups_tensor.T)

          y_pred_test = net_glob(test_data1.data)
          err_test = error_rate(test_data1.targets, y_pred_test)
          max_viol_test, _ = violation(
              test_data1.targets, y_pred_test, args.epsilon, test_data1.true_groups_tensor.T)

          tb_writer.add_scalar('Train/Accuracy', 1-err, com)
          tb_writer.add_scalar('Test/Accuracy',1-err_test, com)
          tb_writer.add_scalar('Train/Max_violation',max_viol, com)
          tb_writer.add_scalar('Test/Max_violation',max_viol_test, com)

          print("Epoch %d| Comm rounds %d | Train accuracy = %.3f | Test accuracy = %.3f  Viol = %.3f | Viol_test = %.3f" %
                (step, com, 1-err, 1-err_test, max_viol, max_viol_test), flush=True)
        # net_glob is deep-copied into the clients; hand them train mode.
        net_glob.train()
  finally:
    tb_writer.close()
  return

def train_fedavg(args, device=None):
  """Train the Adult linear model with FedAvg, logging to TensorBoard.

  In every iteration all clients (in random order) train locally, starting
  from a copy of the global model (``LocalUpdateFedAVG.train``). The global
  weights are then replaced by ``FedAvg`` of the client weights. Every 100
  iterations, train/test accuracy and the maximum TPR violation w.r.t. the
  true groups are logged to TensorBoard and printed.

  Args:
    args: parsed command-line namespace (see the ``__main__`` block).
    device: torch device to train on; defaults to CUDA when available and
      to CPU otherwise.
  """
  # LocalUpdateFedAVG is not in the module-level imports from ``update``.
  # Import it here so this function does not fail with NameError.
  # NOTE(review): assumes ``update`` defines LocalUpdateFedAVG -- confirm.
  from update import LocalUpdateFedAVG

  print("FedAVG")
  print("GPU: %d, clients: %d, batch size: %d, lambda: %f, I: %d, local batch: %d" %(args.GPU, args.num_users, args.local_bs, args.lmbda, args.I, args.local_bs))
  tb_writer = SummaryWriter(log_dir=args.log_dir)
  if device is None:
    # Do not hard-require CUDA: fall back to CPU when no GPU is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  # Fix random seeds for reproducibility.
  torch.manual_seed(args.random_seed)
  np.random.seed(args.random_seed)
  dataset_train, dataset_test = load_data_adult(args.noise_level,
                                                uniform_groups=args.uniform_groups,
                                                min_group_frac=args.min_group_frac,
                                                use_noise_array=args.use_noise_array,
                                                group_features_type=args.group_features_type,
                                                num_group_clusters=args.num_group_clusters)

  # Split the whole training set across the clients.
  dict_users = dataset_iid(dataset=dataset_train, num_users=args.num_users)

  net_glob = Linear(dataset_train.num_features).to(device)
  net_glob.train()
  w_glob = net_glob.state_dict()
  w_locals = [w_glob for _ in range(args.num_users)]

  try:
    for step in range(args.epochs):
      if args.resample_proxy_groups and step % args.epochs_per_resample == 0:
        # Redraw the proxy groups once every epochs_per_resample iterations.
        if args.uniform_groups:
          dataset_train.proxy_groups_tensor = generate_proxy_groups_uniform(
              len(dataset_train), min_group_frac=args.min_group_frac).long().to(device)
        elif args.use_noise_array:
          dataset_train.noise_array = dataset_train.get_noise_array()
          dataset_train.proxy_groups_tensor = generate_proxy_groups_noise_array(
              dataset_train.proxy_groups_tensor, noise_array=dataset_train.noise_array).long().to(device)
        else:
          dataset_train.proxy_groups_tensor = generate_proxy_groups_single_noise(
              dataset_train.proxy_groups_tensor, noise_param=args.noise_level).long().to(device)

      adjust_curlr(step, args, optimizer=None)

      loss_locals = []
      # A random permutation of all clients: every client participates.
      idxs_users = np.random.choice(range(args.num_users), args.num_users, replace=False)

      for idx in idxs_users:
        local = LocalUpdateFedAVG(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, loss = local.train(net=deepcopy(net_glob))
        w_locals[idx] = deepcopy(w)
        loss_locals.append(loss)

      w_glob = FedAvg(w_locals)
      net_glob.load_state_dict(w_glob)

      if step % 100 == 0:
        # Evaluate on copies so the original datasets stay where they are.
        train_data1 = deepcopy(dataset_train)
        test_data1 = deepcopy(dataset_test)
        train_data1.to(device)
        test_data1.to(device)
        net_glob.eval()
        with torch.no_grad():
          y_pred_t = net_glob(train_data1.data)
          err = error_rate(train_data1.targets, y_pred_t)
          max_viol, _ = violation(
              train_data1.targets, y_pred_t, args.epsilon, train_data1.true_groups_tensor.T)

          y_pred_test = net_glob(test_data1.data)
          err_test = error_rate(test_data1.targets, y_pred_test)
          max_viol_test, _ = violation(
              test_data1.targets, y_pred_test, args.epsilon, test_data1.true_groups_tensor.T)

          tb_writer.add_scalar('Train/Accuracy', 1-err, step)
          tb_writer.add_scalar('Test/Accuracy',1-err_test, step)
          tb_writer.add_scalar('Train/Max_violation',max_viol, step)
          tb_writer.add_scalar('Test/Max_violation',max_viol_test, step)

          print("Epoch %d | Train accuracy = %.3f | Test accuracy = %.3f  Viol = %.3f | Viol_test = %.3f" %
                (step, 1-err, 1-err_test, max_viol, max_viol_test), flush=True)
        # net_glob is deep-copied into the clients; hand them train mode.
        net_glob.train()
  finally:
    tb_writer.close()
  return


def train_unconstrained(args, device=None):
  """Train the Adult linear model without fairness constraints.

  Each outer iteration samples a single client. That client runs at most
  ``args.I`` Adagrad steps on its shard with ``BCEWithLogitsLoss``, over up
  to ``args.local_ep`` passes of its DataLoader. The result is stored in
  ``w_locals``, and the global model is set to ``FedAvg(w_locals)``. Every 10
  iterations, train/test accuracy and the maximum TPR violation w.r.t. the
  true groups are logged to TensorBoard and printed.

  Args:
    args: parsed command-line namespace (see the ``__main__`` block).
    device: torch device to train on; defaults to CUDA when available and
      to CPU otherwise.
  """
  print("UNCONSTRAINED")
  tb_writer = SummaryWriter(log_dir=args.log_dir)
  if device is None:
    # Do not hard-require CUDA: fall back to CPU when no GPU is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  # Fix random seeds for reproducibility.
  torch.manual_seed(args.random_seed)
  np.random.seed(args.random_seed)
  dataset_train, dataset_test = load_data_adult(args.noise_level,
                                                uniform_groups=args.uniform_groups,
                                                min_group_frac=args.min_group_frac,
                                                use_noise_array=args.use_noise_array,
                                                group_features_type=args.group_features_type,
                                                num_group_clusters=args.num_group_clusters)

  # Split the whole training set across the clients.
  dict_users = dataset_iid(dataset=dataset_train, num_users=args.num_users)

  net_glob = Linear(dataset_train.num_features).to(device)
  net_glob.train()
  w_glob = net_glob.state_dict()
  # NOTE(review): every entry starts as the *same* state_dict, whose tensors
  # share storage with net_glob's parameters. Until a client is sampled, its
  # entry therefore tracks the global model. Afterwards it holds that
  # client's own last local weights, and training resumes from those (see
  # load_state_dict below), not from the aggregate. Confirm intended.
  w_locals = [w_glob for _ in range(args.num_users)]
  criterion = nn.BCEWithLogitsLoss()

  try:
    for step in range(args.epochs):
      if args.resample_proxy_groups and step % args.epochs_per_resample == 0:
        # Redraw the proxy groups once every epochs_per_resample iterations.
        if args.uniform_groups:
          dataset_train.proxy_groups_tensor = generate_proxy_groups_uniform(
              len(dataset_train), min_group_frac=args.min_group_frac).long().to(device)
        elif args.use_noise_array:
          dataset_train.noise_array = dataset_train.get_noise_array()
          dataset_train.proxy_groups_tensor = generate_proxy_groups_noise_array(
              dataset_train.proxy_groups_tensor, noise_array=dataset_train.noise_array).long().to(device)
        else:
          dataset_train.proxy_groups_tensor = generate_proxy_groups_single_noise(
              dataset_train.proxy_groups_tensor, noise_param=args.noise_level).long().to(device)

      loss_locals = []
      m = 1  # number of clients sampled per iteration
      idxs_users = np.random.choice(range(args.num_users), m, replace=False)
      for idx in idxs_users:
        ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users[idx]), batch_size=args.local_bs, shuffle=True)
        net = deepcopy(net_glob)
        net.load_state_dict(w_locals[idx])
        optim_cur = torch.optim.Adagrad(net.parameters(), lr=args.lr)

        net.train()
        count = 0  # local steps taken so far; capped at args.I
        epoch_loss = []
        for _ in range(args.local_ep):
          batch_loss = []
          for images, labels in ldr_train:
            images, labels = images.to(device), labels.to(device)
            net.zero_grad()
            log_probs = net(images)
            main_loss = criterion(log_probs, labels)
            main_loss.backward()
            optim_cur.step()
            batch_loss.append(main_loss.item())
            count += 1
            if count >= args.I:
              break
          # An empty shard yields no batches; avoid dividing by zero.
          if batch_loss:
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
          if count >= args.I:
            break
        loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else float('nan')
        w_locals[idx] = deepcopy(net.state_dict())
        loss_locals.append(deepcopy(loss))

      w_glob = FedAvg(w_locals)
      net_glob.load_state_dict(w_glob)

      if step % 10 == 0:
        # Evaluate on copies so the original datasets stay where they are.
        train_data1 = deepcopy(dataset_train)
        test_data1 = deepcopy(dataset_test)
        train_data1.to(device)
        test_data1.to(device)
        net_glob.eval()
        with torch.no_grad():
          y_pred_t = net_glob(train_data1.data)
          err = error_rate(train_data1.targets, y_pred_t)
          max_viol, _ = violation(
              train_data1.targets, y_pred_t, args.epsilon, train_data1.true_groups_tensor.T)

          y_pred_test = net_glob(test_data1.data)
          err_test = error_rate(test_data1.targets, y_pred_test)
          max_viol_test, _ = violation(
              test_data1.targets, y_pred_test, args.epsilon, test_data1.true_groups_tensor.T)

          tb_writer.add_scalar('Train/Accuracy', 1-err, step)
          tb_writer.add_scalar('Test/Accuracy',1-err_test, step)
          tb_writer.add_scalar('Train/Max_violation',max_viol, step)
          tb_writer.add_scalar('Test/Max_violation',max_viol_test, step)

          print("Epoch %d | Train accuracy = %.3f | Test accuracy = %.3f | Viol = %.3f | Viol_test = %.3f" %
                (step, 1-err, 1-err_test, max_viol, max_viol_test), flush=True)
        # net_glob is deep-copied for local training; keep it in train mode.
        net_glob.train()
  finally:
    tb_writer.close()
  return

def train_constrained(args, device=None):
  """Train the Adult linear model centrally with the constrained objective.

  This is full-batch training. Each epoch it:
    * computes the BCE-with-logits loss of ``model``;
    * computes the ``lagrangian_loss`` penalty on the (optionally resampled)
      proxy groups for both ``model`` and ``model_last``;
    * delegates the backward pass to ``model.backward_dro``;
    * copies ``model``'s weights into ``model_last``;
    * steps the 'main' Adagrad optimizer.
  Every 10 epochs, train/test accuracy and the maximum TPR violation w.r.t.
  the true groups are logged to TensorBoard and printed.

  Args:
    args: parsed command-line namespace (see the ``__main__`` block).
    device: torch device to train on; defaults to CPU.
  """
  print("GCIVR - CONSTRAINED")
  tb_writer = SummaryWriter(log_dir=args.log_dir)
  if device is None:
    device = torch.device('cpu')

  # Seed like the other training entry points so runs are reproducible.
  torch.manual_seed(args.random_seed)
  np.random.seed(args.random_seed)

  criterion = nn.BCEWithLogitsLoss()
  train_data, test_data = load_data_adult(args.noise_level,
                                          uniform_groups=args.uniform_groups,
                                          min_group_frac=args.min_group_frac,
                                          use_noise_array=args.use_noise_array,
                                          group_features_type=args.group_features_type,
                                          num_group_clusters=args.num_group_clusters)

  train_data.to(device)
  test_data.to(device)

  # The model must live on the same device as the data. Move it before the
  # optimizers are built so their state is allocated there too.
  model = Linear(train_data.num_features).to(device)
  model_last = deepcopy(model).to(device)

  optimizers = {'main': torch.optim.Adagrad(model.parameters(), lr=args.lr),
                'last': torch.optim.Adagrad(model_last.parameters(), lr=args.lr)
  }
  # Placeholders; overwritten every epoch before backward_dro uses them.
  const_losses = {'main': torch.zeros(1), 'last': torch.zeros(1)}

  try:
    for e in range(args.epochs):
      if args.resample_proxy_groups and args.constrained and e % args.epochs_per_resample == 0:
        # Redraw the proxy groups once every epochs_per_resample epochs.
        if args.uniform_groups:
          train_data.proxy_groups_tensor = generate_proxy_groups_uniform(
              len(train_data), min_group_frac=args.min_group_frac).long().to(device)
        elif args.use_noise_array:
          train_data.noise_array = train_data.get_noise_array()
          train_data.proxy_groups_tensor = generate_proxy_groups_noise_array(
              train_data.proxy_groups_tensor, noise_array=train_data.noise_array).long().to(device)
        else:
          train_data.proxy_groups_tensor = generate_proxy_groups_single_noise(
              train_data.proxy_groups_tensor, noise_param=args.noise_level).long().to(device)
      model.train()
      model_last.train()
      optimizers['main'].zero_grad()
      y_pred = model(train_data.data)
      main_loss = criterion(y_pred, train_data.targets)
      const_losses['main'] = lagrangian_loss(train_data.targets, y_pred, args.epsilon, train_data.proxy_groups_tensor,
                                             criterion, device, gamma=args.dual_scale)
      y_pred_last = model_last(train_data.data)
      const_losses['last'] = lagrangian_loss(train_data.targets, y_pred_last, args.epsilon, train_data.proxy_groups_tensor,
                                             criterion, device, gamma=args.dual_scale)
      # NOTE(review): optimizers['last'] is never zero_grad()-ed here;
      # presumably backward_dro handles model_last's gradients -- confirm.
      model.backward_dro(main_loss, const_losses, optimizers, model_last)
      # Snapshot the weights *before* the step so model_last holds the
      # previous iterate in the next epoch.
      model_last.load_state_dict(model.state_dict().copy())
      optimizers['main'].step()

      # Evaluate once every 10 epochs.
      if e % 10 == 0:
        model.eval()
        with torch.no_grad():
          y_pred_t = model(train_data.data)
          err = error_rate(train_data.targets, y_pred_t)
          max_viol, _ = violation(
              train_data.targets, y_pred_t, args.epsilon, train_data.true_groups_tensor.T)

          y_pred_test = model(test_data.data)
          err_test = error_rate(test_data.targets, y_pred_test)
          max_viol_test, _ = violation(
              test_data.targets, y_pred_test, args.epsilon, test_data.true_groups_tensor.T)

          tb_writer.add_scalar('Train/Accuracy', 1-err, e)
          tb_writer.add_scalar('Test/Accuracy',1-err_test, e)
          tb_writer.add_scalar('Train/Max_violation',max_viol, e)
          tb_writer.add_scalar('Test/Max_violation',max_viol_test, e)

          print("Epoch %d | Train accuracy = %.3f | Test accuracy = %.3f | Viol = %.3f | Viol_test = %.3f" %
                (e, 1-err, 1-err_test, max_viol, max_viol_test), flush=True)
  finally:
    tb_writer.close()
  return

if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-dsdro', '--dsdro', action='store_true')
    parser.add_argument('-dro', '--dro', action='store_true')
    parser.add_argument('-fedavg', '--fedavg', action='store_true')
    parser.add_argument('-con', '--constrained', action='store_true')
    parser.add_argument('-e', '--epsilon', default=0.01, type=float)
    parser.add_argument('-b', '--batch_size', default=128, type=int)
    parser.add_argument('-l', '--lr', default=0.01, type=float)
    parser.add_argument('-i', '--iterations',default=50000, type=int)
    parser.add_argument('-com', '--com',default=1000, type=int)
    parser.add_argument('-p', '--epochs',default=5000, type=int)
    parser.add_argument('-o', '--log',default=100, type=int)
    parser.add_argument('-g', '--dual_scale', default=0.1, type=float)
    parser.add_argument('-d', '--logdir', default='/home/aditi/Professor/Untitled Folder/F/results/rafi/exp ', type=str)
    parser.add_argument('-n', '--noise_level', default=0.3, type=float)
    parser.add_argument('-un', '--use_noise_array', action='store_true')
    parser.add_argument('-ug', '--uniform_groups', action='store_true')
    parser.add_argument('-gf', '--min_group_frac', default=0.01, type=float)
    parser.add_argument('-ft', '--group_features_type', default='full_group_vec', type=str)
    parser.add_argument('-nc', '--num_group_clusters',default=100, type=int)
    parser.add_argument('-rp', '--resample_proxy_groups', action='store_true')
    parser.add_argument('-er', '--epochs_per_resample', default=1, type=int)
    parser.add_argument('-f', '--full_step',default=100, type=int)
    parser.add_argument('-k', '--num_users',default=8, type=int)
    parser.add_argument('-rs', '--random_seed', default=3, type=int)
    parser.add_argument('--curbeta', default=0.1, type=float, help='current learning rate')
    parser.add_argument('--beta', default=0.6, type = float, help = 'momentum parameters for SCCMA')
    parser.add_argument('--curlr', default=0.01, type=float, help='current learning rate')
    parser.add_argument('--local_ep', default=1, type=int, help='local_epoch for update')
    parser.add_argument('--local_bs', type=int, default=16, help="local batch size: B")
    parser.add_argument('--lmbda', type=float, default=0.1, help="lmbda for update y_kt")
    parser.add_argument('-I', default=1, type=int, help='the frequency for FL communication to updata global model')
    parser.add_argument('-gpu', '--GPU', default=0, type=int)
    parser.add_argument('-mu', '--mu', default=0, type=float)
    parser.add_argument('--gamma_x', default=1.4, type=float, help='learning rate, gamma_x')
    parser.add_argument('--gamma_y', default=1.4, type=float, help='learning rate, gamma_y')

    args = parser.parse_args()

    # Create a fresh, numbered run directory and record the full config in it.
    args.log_dir = increment_dir(Path(args.logdir) / 'exp')
    os.makedirs(args.log_dir)
    print(args.log_dir)
    yaml_file = str(Path(args.log_dir) / "args.yaml")
    with open(yaml_file, 'w') as out:
        yaml.dump(args.__dict__, out, default_flow_style=False)

    # Device configuration: --GPU 0 selects cuda:0, any other value cuda:1;
    # CPU when CUDA is unavailable.
    if args.GPU == 0:
        device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')

    # Pass the selected device explicitly so --GPU (and the CPU fallback)
    # actually takes effect in the trainers.
    if args.dro:
        train_DDRO(args, device=device)
    elif args.fedavg:
        train_fedavg(args, device=device)
    elif args.dsdro:
        train_DSDDRO(args, device=device)
    # NOTE: train_unconstrained / train_constrained are not dispatched from
    # the command line.