# coding=utf-8
"""Fairness with noisy protected groups experiments."""
import argparse
import os
import yaml
from pathlib import Path

import numpy as np
from sklearn.cluster import KMeans
import torch
import torch.nn as nn
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter

from utils import generate_proxy_groups_uniform, generate_proxy_groups_noise_array,\
                  generate_proxy_groups_single_noise, error_rate, group_error_rates, \
                  tpr, group_tprs,increment_dir, dataset_iid
from data import load_data_adult
from model import Linear
from update import adjust_curlr, LocalUpdateDDRO, FedAvg, DatasetSplit
from torch.utils.data import DataLoader, Dataset


def violation(labels, predictions, epsilon, groups):
  """Compute per-group constraint violations and their maximum.

  For every column of `groups`, the violation is
  `tpr(labels, predictions) - group_tprs(labels, predictions, column) - epsilon`,
  where `tpr` and `group_tprs` are the helpers imported from `utils`
  (presumably overall and per-group true-positive rates -- confirm in utils).

  Args:
    labels: Targets forwarded unchanged to `tpr` / `group_tprs`.
    predictions: Model outputs forwarded unchanged to `tpr` / `group_tprs`.
    epsilon: Slack subtracted from every violation.
    groups: 2-D group-membership tensor; each column is one group.

  Returns:
    A tuple `(max_violation, violations)` where `violations` holds one CPU
    value per group column and `max_violation` is `np.max` over that list.
  """
  baseline = tpr(labels, predictions).cpu()
  per_group = [
      baseline - group_tprs(labels, predictions, groups[:, col]).cpu() - epsilon
      for col in range(groups.shape[1])
  ]
  return np.max(per_group), per_group

def lagrangian_loss(labels, predictions, epsilon,
                    groups, criterion, device, gamma=0.1, alpha=0.1):
  """Return the Lagrangian loss of the constrained problem.

  For each group column k, the group loss is `criterion` evaluated on the
  rows that belong to the group and whose label is >= 1.0. Its constraint
  term is `group_loss - main_loss - epsilon`, and the contribution added to
  the total is `1/(m+1) * (1/m + exp(constraint * alpha / gamma))`, with `m`
  the number of group columns.

  Args:
    labels: Target tensor; compared against 1.0 and squeezed to build the
      positive-label mask.
    predictions: Model outputs, row-aligned with `labels`.
    epsilon: Allowed slack in each group constraint.
    groups: 2-D group-membership tensor, one column per group.
    criterion: Callable `criterion(predictions, labels)` returning a loss.
    device: Device on which the accumulator tensor is created.
    gamma: Divisor applied inside the exponential.
    alpha: Multiplier applied inside the exponential.

  Returns:
    Tensor of shape (1,) on `device` holding the summed contributions
    (all zeros when `groups` has no columns).
  """
  num_groups = groups.shape[1]
  accumulated = torch.zeros(1).to(device)
  overall = criterion(predictions, labels)
  positives = (labels >= 1.0).squeeze()
  outer_weight = 1/(num_groups+1)

  for col in range(num_groups):
    # Rows that are in this group and carry a positive label.
    selected = torch.logical_and(groups[:, col], positives)
    slack = criterion(predictions[selected], labels[selected]) - overall - epsilon
    penalty = torch.exp(slack * alpha / gamma)
    # In-place accumulation keeps the accumulator's dtype, as before.
    accumulated.add_(outer_weight * (1/num_groups + penalty))
  return accumulated

def train_DDRO(args, device=None):
  """Run federated DDRO training and log periodic evaluation metrics.

  Each round optionally resamples the proxy groups of the training set,
  picks one client at random, refreshes that client's `y_kt` via
  `LocalUpdateDDRO.update_ykt`, averages all `y_kt` into `y_t`, runs the
  client's local update with `LocalUpdateDDRO.train`, and aggregates every
  entry of `w_locals` with `FedAvg` into the global model. Every 10 rounds
  train/test accuracy and the maximum violation on the true groups are
  written to TensorBoard and printed.

  Args:
    args: Parsed command-line namespace. Reads num_users, batch_size, lmbda,
      I, local_bs, log_dir, random_seed, epochs, epsilon, the
      resample_proxy_groups / epochs_per_resample options and the
      data-loading options forwarded to `load_data_adult`.
    device: torch.device to run on. When None, CUDA is used if available,
      otherwise CPU.

  Returns:
    None.
  """
  print("clients: %d, batch size: %d, lambda: %f, I: %d, local batch: %d" %(args.num_users, args.batch_size, args.lmbda, args.I, args.local_bs))
  tb_writer = SummaryWriter(log_dir=args.log_dir)
  if device is None:
    # Fall back to CPU instead of failing on hosts without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  try:
    ## fix random seeds
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)
    dataset_train, dataset_test = load_data_adult(args.noise_level,
                                            uniform_groups=args.uniform_groups,
                                            min_group_frac=args.min_group_frac,
                                            use_noise_array=args.use_noise_array,
                                            group_features_type=args.group_features_type,
                                            num_group_clusters=args.num_group_clusters)

    # split the whole train dataset for each user
    dict_users = dataset_iid(dataset=dataset_train, num_users=args.num_users)

    net_glob = Linear(dataset_train.num_features).to(device)

    # begin training
    net_glob.train()
    # copy weights
    w_glob = net_glob.state_dict()
    # NOTE(review): w_glob_prev is captured once here and never refreshed, so
    # update_ykt always receives the initial weights as the "previous" model.
    # Confirm whether it should track the previous round instead.
    w_glob_prev = deepcopy(w_glob)

    # Per-client weights and y_kt estimates, aggregated over all clients.
    w_locals = [w_glob for _ in range(args.num_users)]
    y_kt = [0.0 for _ in range(args.num_users)]

    # Number of clients sampled per round.
    num_selected = 1

    for round_idx in range(args.epochs):
      # Only resample proxy groups every epochs_per_resample rounds.
      # (The original tested the builtin `round` here, which raised TypeError.)
      if args.resample_proxy_groups and round_idx % args.epochs_per_resample == 0:
        # Resample the group at the beginning of the round.
        if args.uniform_groups:
          dataset_train.proxy_groups_tensor = generate_proxy_groups_uniform(
              len(dataset_train), min_group_frac=args.min_group_frac).long().to(device)
        elif args.use_noise_array:
          dataset_train.noise_array = dataset_train.get_noise_array()
          dataset_train.proxy_groups_tensor = generate_proxy_groups_noise_array(
              dataset_train.proxy_groups_tensor, noise_array=dataset_train.noise_array).long().to(device)
        else:
          dataset_train.proxy_groups_tensor = generate_proxy_groups_single_noise(
              dataset_train.proxy_groups_tensor, noise_param=args.noise_level).long().to(device)

      idxs_users = np.random.choice(range(args.num_users), num_selected, replace=False)

      # update y_kt of the sampled clients
      for idx in idxs_users:
        local = LocalUpdateDDRO(args=args, dataset=dataset_train, idxs=dict_users[idx])
        y_kt[idx] = local.update_ykt(deepcopy(net_glob).to(device),
                                     w_glob, w_glob_prev, ykt=y_kt[idx])
      # y_t is the plain average over every client's y_kt
      y_t = sum(y_kt)/len(y_kt)

      # local update; the returned loss is not used anywhere
      for idx in idxs_users:
        local = LocalUpdateDDRO(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, _ = local.train(net=deepcopy(net_glob), y_t=y_t)
        w_locals[idx] = deepcopy(w)

      # update global weights and copy them into net_glob
      w_glob = FedAvg(w_locals)
      net_glob.load_state_dict(w_glob)

      if round_idx % 10 == 0:
        # Evaluate on device-side copies so the originals are left untouched.
        train_data1 = deepcopy(dataset_train)
        test_data1 = deepcopy(dataset_test)
        train_data1.to(device)
        test_data1.to(device)
        net_glob.eval()
        with torch.no_grad():
          y_pred_t = net_glob(train_data1.data)
          err = error_rate(train_data1.targets, y_pred_t)
          max_viol, viol_list = violation(
              train_data1.targets, y_pred_t, args.epsilon, train_data1.true_groups_tensor.T)

          y_pred_test = net_glob(test_data1.data)
          err_test = error_rate(test_data1.targets, y_pred_test)
          max_viol_test, viol_list_test = violation(
              test_data1.targets, y_pred_test, args.epsilon, test_data1.true_groups_tensor.T)

          tb_writer.add_scalar('Train/Accuracy', 1-err, round_idx)
          tb_writer.add_scalar('Test/Accuracy',1-err_test, round_idx)
          tb_writer.add_scalar('Train/Max_violation',max_viol, round_idx)
          tb_writer.add_scalar('Test/Max_violation',max_viol_test, round_idx)

          print("Epoch %d | Train accuracy = %.3f | Test accuracy = %.3f  Viol = %.3f | Viol_test = %.3f" %
                (round_idx, 1-err, 1-err_test, max_viol, max_viol_test), flush=True)
  finally:
    # Flush pending TensorBoard events even if training fails.
    tb_writer.close()
  return

def train_unconstrained(args, device=None):
  """Run federated training with the plain BCE-with-logits loss.

  Each round optionally resamples the proxy groups of the training set,
  picks one client at random, trains a copy of the global model (initialised
  from that client's entry in `w_locals`) with Adagrad for at most `args.I`
  mini-batches spread over up to `args.local_ep` local epochs, and aggregates
  every entry of `w_locals` with `FedAvg` into the global model. Every 10
  rounds train/test accuracy and the maximum violation on the true groups are
  written to TensorBoard and printed.

  Args:
    args: Parsed command-line namespace. Reads log_dir, random_seed,
      num_users, epochs, lr, local_bs, local_ep, I, epsilon, the
      resample_proxy_groups / epochs_per_resample options and the
      data-loading options forwarded to `load_data_adult`.
    device: torch.device to run on. When None, CUDA is used if available,
      otherwise CPU.

  Returns:
    None.
  """
  tb_writer = SummaryWriter(log_dir=args.log_dir)
  if device is None:
    # Fall back to CPU instead of failing on hosts without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  try:
    ## fix random seeds
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)
    dataset_train, dataset_test = load_data_adult(args.noise_level,
                                            uniform_groups=args.uniform_groups,
                                            min_group_frac=args.min_group_frac,
                                            use_noise_array=args.use_noise_array,
                                            group_features_type=args.group_features_type,
                                            num_group_clusters=args.num_group_clusters)

    # split the whole train dataset for each user
    dict_users = dataset_iid(dataset=dataset_train, num_users=args.num_users)

    net_glob = Linear(dataset_train.num_features).to(device)

    # begin training
    net_glob.train()
    # copy weights
    w_glob = net_glob.state_dict()

    # Per-client weights, aggregated over all clients.
    w_locals = [w_glob for _ in range(args.num_users)]
    criterion = nn.BCEWithLogitsLoss()

    # Number of clients sampled per round.
    num_selected = 1

    for round_idx in range(args.epochs):
      # Only resample proxy groups every epochs_per_resample rounds.
      if args.resample_proxy_groups and round_idx % args.epochs_per_resample == 0:
        # Resample the group at the beginning of the round.
        if args.uniform_groups:
          dataset_train.proxy_groups_tensor = generate_proxy_groups_uniform(
              len(dataset_train), min_group_frac=args.min_group_frac).long().to(device)
        elif args.use_noise_array:
          dataset_train.noise_array = dataset_train.get_noise_array()
          dataset_train.proxy_groups_tensor = generate_proxy_groups_noise_array(
              dataset_train.proxy_groups_tensor, noise_array=dataset_train.noise_array).long().to(device)
        else:
          dataset_train.proxy_groups_tensor = generate_proxy_groups_single_noise(
              dataset_train.proxy_groups_tensor, noise_param=args.noise_level).long().to(device)

      idxs_users = np.random.choice(range(args.num_users), num_selected, replace=False)
      for idx in idxs_users:
        ldr_train = DataLoader(DatasetSplit(dataset_train, dict_users[idx]), batch_size=args.local_bs, shuffle=True)
        net = deepcopy(net_glob)
        # The client resumes from its own stored weights, not the fresh global.
        net.load_state_dict(w_locals[idx])
        optim_cur = torch.optim.Adagrad(net.parameters(), lr=args.lr)

        net.train()
        # Total optimizer steps taken by this client in this round; local
        # training stops once args.I steps have been made. The per-batch loss
        # averages of the original were never read and have been dropped.
        steps = 0
        for _ in range(args.local_ep):
          for features, labels in ldr_train:
            features, labels = features.to(device), labels.to(device)
            net.zero_grad()
            main_loss = criterion(net(features), labels)
            main_loss.backward()
            optim_cur.step()
            steps += 1
            if steps >= args.I:
              break
          if steps >= args.I:
            break
        w_locals[idx] = deepcopy(net.state_dict())

      # update global weights and copy them into net_glob
      w_glob = FedAvg(w_locals)
      net_glob.load_state_dict(w_glob)

      if round_idx % 10 == 0:
        # Evaluate on device-side copies so the originals are left untouched.
        train_data1 = deepcopy(dataset_train)
        test_data1 = deepcopy(dataset_test)
        train_data1.to(device)
        test_data1.to(device)
        net_glob.eval()
        with torch.no_grad():
          y_pred_t = net_glob(train_data1.data)
          err = error_rate(train_data1.targets, y_pred_t)
          max_viol, viol_list = violation(
              train_data1.targets, y_pred_t, args.epsilon, train_data1.true_groups_tensor.T)

          y_pred_test = net_glob(test_data1.data)
          err_test = error_rate(test_data1.targets, y_pred_test)
          max_viol_test, viol_list_test = violation(
              test_data1.targets, y_pred_test, args.epsilon, test_data1.true_groups_tensor.T)

          tb_writer.add_scalar('Train/Accuracy', 1-err, round_idx)
          tb_writer.add_scalar('Test/Accuracy',1-err_test, round_idx)
          tb_writer.add_scalar('Train/Max_violation',max_viol, round_idx)
          tb_writer.add_scalar('Test/Max_violation',max_viol_test, round_idx)

          print("Epoch %d | Train accuracy = %.3f | Test accuracy = %.3f | Viol = %.3f | Viol_test = %.3f" %
                (round_idx, 1-err, 1-err_test, max_viol, max_viol_test), flush=True)
  finally:
    # Flush pending TensorBoard events even if training fails.
    tb_writer.close()
  return

def train_constrained(args, device=None):
  """Train a single model with the Lagrangian constraint loss (full batch).

  Every epoch optionally resamples the proxy groups, evaluates the main BCE
  loss and the `lagrangian_loss` for both the current model and the previous
  iterate (`model_last`), calls `model.backward_dro`, stores the pre-step
  weights in `model_last`, and then steps the main optimizer. Every 10 epochs
  train/test accuracy and the maximum violation on the true groups are
  written to TensorBoard and printed.

  Args:
    args: Parsed command-line namespace. Reads log_dir, lr, epochs, epsilon,
      dual_scale, constrained, the resample_proxy_groups /
      epochs_per_resample options and the data-loading options forwarded to
      `load_data_adult`.
    device: torch.device to run on. Defaults to CPU when None.

  Returns:
    None.
  """
  tb_writer = SummaryWriter(log_dir=args.log_dir)
  if device is None:
    device = torch.device('cpu')

  try:
    # NOTE(review): unlike the federated trainers in this file, no random
    # seed is fixed here -- confirm whether args.random_seed should apply.
    criterion = nn.BCEWithLogitsLoss()
    train_data, test_data = load_data_adult(args.noise_level,
                                            uniform_groups=args.uniform_groups,
                                            min_group_frac=args.min_group_frac,
                                            use_noise_array=args.use_noise_array,
                                            group_features_type=args.group_features_type,
                                            num_group_clusters=args.num_group_clusters)

    train_data.to(device)
    test_data.to(device)

    # The model must live on the same device as the data; previously only
    # model_last was moved, which broke any non-CPU device.
    model = Linear(train_data.num_features).to(device)
    model_last = deepcopy(model).to(device)

    optimizers = {'main': torch.optim.Adagrad(model.parameters(), lr=args.lr),
                  'last': torch.optim.Adagrad(model_last.parameters(), lr=args.lr)
    }
    const_losses = {'main': torch.zeros(1), 'last': torch.zeros(1)}

    for e in range(args.epochs):
      # Only resample proxy groups every epochs_per_resample epochs.
      if (args.resample_proxy_groups and args.constrained
          and e % args.epochs_per_resample == 0):
        # Resample the group at the beginning of the epoch.
        if args.uniform_groups:
          train_data.proxy_groups_tensor = generate_proxy_groups_uniform(
              len(train_data), min_group_frac=args.min_group_frac).long().to(device)
        elif args.use_noise_array:
          train_data.noise_array = train_data.get_noise_array()
          train_data.proxy_groups_tensor = generate_proxy_groups_noise_array(
              train_data.proxy_groups_tensor, noise_array=train_data.noise_array).long().to(device)
        else:
          train_data.proxy_groups_tensor = generate_proxy_groups_single_noise(
              train_data.proxy_groups_tensor, noise_param=args.noise_level).long().to(device)

      model.train()
      model_last.train()
      optimizers['main'].zero_grad()
      y_pred = model(train_data.data)
      main_loss = criterion(y_pred, train_data.targets)
      const_losses['main'] = lagrangian_loss(train_data.targets, y_pred, args.epsilon, train_data.proxy_groups_tensor,
                                    criterion, device, gamma=args.dual_scale)
      y_pred_last = model_last(train_data.data)
      const_losses['last'] = lagrangian_loss(train_data.targets, y_pred_last, args.epsilon, train_data.proxy_groups_tensor,
                                    criterion, device, gamma=args.dual_scale)
      model.backward_dro(main_loss, const_losses, optimizers, model_last)
      # Snapshot the weights before the step so model_last holds the
      # previous iterate on the next epoch.
      model_last.load_state_dict(model.state_dict().copy())
      optimizers['main'].step()

      # Evaluate and log once every 10 epochs.
      if e % 10 == 0:
        model.eval()
        with torch.no_grad():
          y_pred_t = model(train_data.data)
          err = error_rate(train_data.targets, y_pred_t)
          max_viol, viol_list = violation(
              train_data.targets, y_pred_t, args.epsilon, train_data.true_groups_tensor.T)

          y_pred_test = model(test_data.data)
          err_test = error_rate(test_data.targets, y_pred_test)
          max_viol_test, viol_list_test = violation(
              test_data.targets, y_pred_test, args.epsilon, test_data.true_groups_tensor.T)

          tb_writer.add_scalar('Train/Accuracy', 1-err, e)
          tb_writer.add_scalar('Test/Accuracy',1-err_test, e)
          tb_writer.add_scalar('Train/Max_violation',max_viol, e)
          tb_writer.add_scalar('Test/Max_violation',max_viol_test, e)

          print("Epoch %d | Train accuracy = %.3f | Test accuracy = %.3f | Viol = %.3f | Viol_test = %.3f" %
                (e, 1-err, 1-err_test, max_viol, max_viol_test), flush=True)
  finally:
    # Flush pending TensorBoard events even if training fails.
    tb_writer.close()
  return

# Command-line entry point: parse options, create a fresh run directory,
# save the options next to the logs, then dispatch to one training routine.
if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-dro', '--DDRO', action='store_true')
    parser.add_argument('-con', '--constrained', action='store_true')
    parser.add_argument('-e', '--epsilon', default=0.01, type=float)
    parser.add_argument('-b', '--batch_size', default=128, type=int)
    parser.add_argument('-l', '--lr', default=0.01, type=float)
    parser.add_argument('-i', '--iterations',default=50000, type=int)
    parser.add_argument('-p', '--epochs',default=5000, type=int)
    parser.add_argument('-o', '--log',default=100, type=int)
    parser.add_argument('-g', '--dual_scale', default=0.1, type=float)
    parser.add_argument('-d', '--logdir', default='/home/rafi/GCIVR', type=str)
    parser.add_argument('-n', '--noise_level', default=0.3, type=float)
    parser.add_argument('-un', '--use_noise_array', action='store_true')
    parser.add_argument('-ug', '--uniform_groups', action='store_true')
    parser.add_argument('-gf', '--min_group_frac', default=0.01, type=float)
    parser.add_argument('-ft', '--group_features_type', default='full_group_vec', type=str)
    parser.add_argument('-nc', '--num_group_clusters',default=100, type=int)
    parser.add_argument('-rp', '--resample_proxy_groups', action='store_true')
    parser.add_argument('-er', '--epochs_per_resample', default=1, type=int)
    parser.add_argument('-f', '--full_step',default=100, type=int)
    parser.add_argument('-k', '--num_users',default=8, type=int)
    parser.add_argument('-rs', '--random_seed', default=3, type=int)
    parser.add_argument('--curbeta', default=0.1, type=float, help='current learning rate')
    parser.add_argument('--beta', default=0.6, type = float, help = 'momentum parameters for SCCMA')
    parser.add_argument('--curlr', default=0.01, type=float, help='current learning rate')
    parser.add_argument('--local_ep', default=1, type=int, help='local_epoch for update')
    parser.add_argument('--local_bs', type=int, default=16, help="local batch size: B")
    parser.add_argument('--lmbda', type=float, default=0.1, help="lmbda for update y_kt")
    parser.add_argument('-I', default=1, type=int, help='the frequency for FL communication to updata global model')

    args = parser.parse_args()

    # Create the run directory under --logdir and store the options in it.
    args.log_dir = increment_dir(Path(args.logdir) / 'exp')
    os.makedirs(args.log_dir)
    yaml_file = str(Path(args.log_dir) / "args.yaml")
    with open(yaml_file, 'w') as out:
        yaml.dump(args.__dict__, out, default_flow_style=False)

    # Device configuration: prefer CUDA, fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.DDRO:
        # Forward the detected device so a CPU-only host does not hit the
        # trainer's CUDA default.
        train_DDRO(args, device=device)
    elif args.constrained:
        # train_constrained defaults to CPU when no device is given; that
        # default is kept here.
        train_constrained(args)
    else:
        train_unconstrained(args, device=device)