# -*- coding: utf-8 -*-

import numpy as np
import sys
import os
import copy
import argparse

import math
import random
import csv

from util import Dataset, Result, ResultRep
from util import get_logger

def _kernel_label(p):
  """Return the value stored under ``title["kernel"]`` for the flags set on ``p``."""
  if p.kernel:
    return "kernel"
  if p.rff:
    return "rff-ns" if p.nonlinears else "rff"
  return "no"


def _append_csv_row(csv_fn, fieldnames, row):
  """Append ``row`` to ``csv_fn``, writing the header first when the file is new or empty."""
  # The check must happen BEFORE opening: open(..., "a") creates the file, so
  # testing for existence afterwards would always succeed and skip the header.
  needs_header = (not os.path.exists(csv_fn)) or os.path.getsize(csv_fn) == 0
  # newline="" is what the csv module expects; without it some platforms
  # emit an extra blank line after every row.
  with open(csv_fn, "a", newline="") as csv_file:
    writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
    if needs_header:
      writer.writeheader()
    writer.writerow(row)


def experiment(p, trainData, testData, eps, hparams, resultRep):
  """Run the unfair and eps-fair predictors once and record their results.

  Args:
    p: parsed command-line namespace; reads ``kernel``, ``rff``,
      ``nonlinears``, ``dataset`` and ``name``.
    trainData: sequence ``(S, X1, X2, Y)`` for training (deep-copied here).
    testData: sequence ``(S, X1, X2, Y)`` for testing (deep-copied here).
    eps: fairness parameter forwarded to ``Dataset.EpsFair_Prediction``.
    hparams: dict providing at least ``"lmd"`` and ``"gamma"``.
    resultRep: collector; receives four ``add_run`` calls
      (fair train/test, then unfair train/test).

  Side effects:
    Logs per-group MSE summaries through the module-level ``logger`` and
    appends one row of metrics to ``<p.name>.csv``.

  Raises:
    ValueError: if the fair test results do not contain exactly two
      sensitive groups.
  """
  trainS, trainX1, trainX2, trainY = copy.deepcopy(trainData)
  testS, testX1, testX2, testY = copy.deepcopy(testData)
  dataset = Dataset(trainS, trainX1, trainX2, trainY)
  dataset.add_testdata(testS, testX1, testX2, testY)

  # avails[j] is True when column j of S takes exactly the values {0, 1}
  # over the union of the train and test rows.
  avails = []
  for j in range(len(testS[0])):
    vals = {s[j] for s in trainS} | {s[j] for s in testS}
    avails.append(vals == {0, 1})

  result_unfair_train, result_unfair_test = dataset.Unfair_Prediction(
      p.kernel or p.rff, hparams["lmd"], hparams["gamma"], avails)
  result_train, result_test = dataset.EpsFair_Prediction(p.dataset, eps, hparams, avails, p)

  # Register the four runs; the title is deep-copied per run because it is
  # mutated between calls.
  title = {"hparam": hparams, "kernel": _kernel_label(p)}
  runs = (
      (eps, result_train, result_test),
      ("unfair", result_unfair_train, result_unfair_test),
  )
  for eps_label, run_train, run_test in runs:
    title["eps"] = eps_label
    title["dataset"] = "train"
    resultRep.add_run(copy.deepcopy(title), run_train)
    title["dataset"] = "test"
    resultRep.add_run(copy.deepcopy(title), run_test)

  #### get our own results ####
  from fairlearn.metrics import group_summary
  from sklearn.metrics import mean_squared_error

  unfair_summary = group_summary(
      mean_squared_error, result_unfair_test.Y, result_unfair_test.Yhat,
      sensitive_features=result_unfair_test.S)
  logger.info("For a normal regressor, MSE summary: {}".format(unfair_summary))

  Y_test, Y_test_pred, A_test = result_test.Y, result_test.Yhat, result_test.S
  # Computed once and reused for both the log line and the CSV metrics.
  results = group_summary(mean_squared_error, Y_test, Y_test_pred, sensitive_features=A_test)
  logger.info("For a fair-train regressor, MSE summary: {}".format(results))

  # log data
  cls_error = results['overall']
  groups = list(results['by_group'].keys())
  if len(groups) != 2:
    # Explicit raise instead of assert so the check survives `python -O`.
    raise ValueError(
        "expected exactly 2 sensitive groups, got {}: {}".format(len(groups), groups))
  error_0 = results['by_group'][groups[0]]
  error_1 = results['by_group'][groups[1]]
  ys_var = np.var(Y_test)
  nmse = cls_error / ys_var
  r_squared = 1 - nmse
  logger.info("R squared = {}".format(r_squared))

  # save data to csv
  csv_data = {
      "cls_error": cls_error,
      "error_0": error_0,
      "error_1": error_1,
      "err_gap": np.abs(error_0 - error_1),
      "R^2": r_squared,
      "nmse": nmse,
  }
  fieldnames = ["cls_error", "error_0", "error_1", "err_gap", "R^2", "nmse"]
  _append_csv_row(p.name + ".csv", fieldnames, csv_data)

def main_single(p, trainData, testData, eps, hparams):
  """Run a single experiment on private copies of the data.

  A fresh ``ResultRep`` is created, handed to ``experiment`` together with
  deep copies of ``trainData`` and ``testData``, and returned afterwards.
  """
  report = ResultRep()
  train_copy = copy.deepcopy(trainData)
  test_copy = copy.deepcopy(testData)
  experiment(p, train_copy, test_copy, eps, hparams, report)
  return report

def main(p):
  """Load the data selected by ``p``, run one experiment and return its results.

  Args:
    p: parsed command-line namespace; ``eps``, ``gamma`` and ``lmd`` are read
      here, and ``p`` is forwarded to ``read_data`` and ``main_single``.

  Returns:
    The ``ResultRep`` produced by ``main_single``.
  """
  # read data
  X1_train, X2_train, Y_train, S_train, X1_test, X2_test, Y_test, S_test = read_data(p)
  print ("S_train.shape=",S_train.shape)
  print ("X1_train.shape=",X1_train.shape)
  print ("X2_train.shape=",X2_train.shape)
  print ("Y_train.shape=",Y_train.shape)
  print("np.sum(X2)=",np.sum(X2_train))

  hparams = {"gamma": p.gamma, "lmd": p.lmd}
  trainData = [S_train, X1_train, X2_train, Y_train]
  testData = [S_test, X1_test, X2_test, Y_test]
  # Hand the report back instead of discarding it so callers can inspect it.
  return main_single(p, trainData, testData, p.eps, hparams)

def read_data(args):
  """Load the train/test split for ``args.dataset``.

  Supported names are ``"adult"``, ``"compas"``, ``"crime"`` and ``"law"``.
  The COMPAS branch reads ``data/propublica.csv``, shuffles it with
  ``np.random`` and splits it 70/30; it also logs through the module-level
  ``logger``. The other branches load pre-split ``train``/``test`` phases.

  Returns:
    Tuple ``(X_train, X2_train, Y_train, A_train, X_test, X2_test, Y_test,
    A_test)``. ``X2_*`` is a column of ones with one row per instance, and
    ``A_*`` has a trailing axis of length 1 appended.

  Raises:
    NotImplementedError: if ``args.dataset`` is not one of the supported names.
  """
  # Loading the dataset
  if args.dataset == "adult":
    from dataset import AdultDataset
    adult_train = AdultDataset(root_dir='data', phase='train', tar_attr="income", priv_attr="sex")
    adult_test = AdultDataset(root_dir='data', phase='test', tar_attr="income", priv_attr="sex")
    X_train, Y_train, A_train = adult_train.X, adult_train.Y.squeeze(), np.argmax(adult_train.A, axis=1)
    X_test, Y_test, A_test = adult_test.X, adult_test.Y.squeeze(), np.argmax(adult_test.A, axis=1)
  elif args.dataset == "compas":
    from dataset import COMPAS
    import pandas as pd
    compas = pd.read_csv("data/propublica.csv").values
    logger.debug("Shape of COMPAS dataset: {}".format(compas.shape))
    num_insts = compas.shape[0]
    logger.info("Total number of instances in the COMPAS data: {}".format(num_insts))
    # Random shuffle the dataset.
    indices = np.arange(num_insts)
    np.random.shuffle(indices)
    compas = compas[indices]
    # Partition the dataset into train and test split (70/30).
    ratio = 0.7
    num_train = int(num_insts * ratio)
    compas_train = COMPAS(compas[:num_train, :])
    compas_test = COMPAS(compas[num_train:, :])
    X_train, Y_train, A_train = compas_train.insts, compas_train.labels.squeeze(), compas_train.attrs
    X_test, Y_test, A_test = compas_test.insts, compas_test.labels.squeeze(), compas_test.attrs
  elif args.dataset == "crime":
    from dataset import CrimeDataset
    crime_train = CrimeDataset(root_dir='data', phase='train')
    crime_test = CrimeDataset(root_dir='data', phase='test')
    X_train, Y_train, A_train = crime_train.X, crime_train.Y.squeeze(), crime_train.A
    X_test, Y_test, A_test = crime_test.X, crime_test.Y.squeeze(), crime_test.A
  elif args.dataset == "law":
    from dataset import LawSchoolDataset
    law_train = LawSchoolDataset(root_dir='data', phase='train')
    law_test = LawSchoolDataset(root_dir='data', phase='test')
    X_train, Y_train, A_train = law_train.X, law_train.Y.squeeze(), law_train.A
    X_test, Y_test, A_test = law_test.X, law_test.Y.squeeze(), law_test.A
  else:
    raise NotImplementedError("Unsupported dataset: {!r}".format(args.dataset))
  # create X2
  X2_train = np.ones((X_train.shape[0], 1))
  X2_test = np.ones((X_test.shape[0], 1))
  # unsqueeze S (or A) to make S (or A) two-dimensional
  A_train = np.expand_dims(A_train, axis=-1)
  A_test = np.expand_dims(A_test, axis=-1)
  return X_train, X2_train, Y_train, A_train, X_test, X2_test, Y_test, A_test

if __name__ == '__main__':
  # Build the command-line interface.
  parser = argparse.ArgumentParser(description='process some integers.')
  parser.add_argument(
      "--name",
      help="Name used to save the log file.",
      type=str,
      default="baseline_coff",
  )
  parser.add_argument(
      '-d', '--dataset',
      action='store',
      nargs='?',
      const=None,
      default="adult",
      type=str,
      choices=["adult", "compas", "crime", "law"],
      help='dataset',
      metavar=None,
  )

  # Boolean switches, registered in the same order as before.
  for short_flag, long_flag in (
      ('-k', '--kernel'),
      ('-r', '--rff'),
      ('-n', '--nonlinears'),
  ):
    parser.add_argument(short_flag, long_flag, action='store_true')

  # Numeric options: (short, long, type, default).
  for short_flag, long_flag, value_type, fallback in (
      ('-s', '--seed', int, 42),     # rand seed
      ('-e', '--eps', float, 0.1),   # eps (fairness parameter)
      ('-g', '--gamma', float, 1.0), # gamma (hyperparameter)
      ('-l', '--lmd', float, 0.0),   # lambda (hyperparameter)
  ):
    parser.add_argument(short_flag, long_flag, action='store', type=value_type, default=fallback)

  args = parser.parse_args()

  # The logger is a module-level name used by the functions above.
  logger = get_logger(args.name)
  logger.info("args: {}".format(args))

  # Seed both RNGs so runs are reproducible.
  np.random.seed(args.seed)
  random.seed(args.seed)

  main(args)


