import argparse
import random
import numpy as np
import torch
from utils import get_logger
import sys
import os
import csv

# Command-line options: run name, random seed, dataset selector and the two
# Bounded Group Loss settings (upper bound and eps).
parser = argparse.ArgumentParser()
parser.add_argument(
    "-n", "--name",
    type=str,
    default="baseline_bgl",
    help="Name used to save the log file.",
)
parser.add_argument(
    "-s", "--seed",
    type=int,
    default=42,
    help="Random seed.",
)
parser.add_argument(
    "-d", "--data",
    type=str,
    default="adult",
    help="Which dataset to run: [adult|compas|crime|law]",
)
parser.add_argument(
    "--ub",
    type=float,
    default=0.12,
    help="bgl upper bound",
)
parser.add_argument(
    "--eps",
    type=float,
    default=0.01,
    help="bgl eps",
)
args = parser.parse_args()

# Seed NumPy, Python's `random` and PyTorch (in that order) with the
# user-supplied seed.
for seed_rng in (np.random.seed, random.seed, torch.manual_seed):
    seed_rng(args.seed)

# `get_logger` is a project helper; it is handed the run name from --name.
logger = get_logger(args.name)

# Record the parsed command-line options at the start of the log.
logger.info("args: {}".format(args))

import pandas as pd
import matplotlib.pyplot as plt

# Loading the dataset
# Every branch binds the same six names used by the rest of the script:
# X_train / Y_train / A_train and X_test / Y_test / A_test (features, targets
# and sensitive attribute for each split). Dataset classes are imported lazily
# so only the selected one is loaded.
if args.data == "adult":
    from dataset import AdultDataset
    adult_train = AdultDataset(root_dir='data', phase='train', tar_attr="income", priv_attr="sex")
    adult_test = AdultDataset(root_dir='data', phase='test', tar_attr="income", priv_attr="sex")
    # argmax over axis 1 collapses A to one group index per row
    # (presumably A is one-hot encoded -- confirm against dataset.py).
    X_train, Y_train, A_train = adult_train.X, adult_train.Y.squeeze(), np.argmax(adult_train.A, axis=1)
    X_test, Y_test, A_test = adult_test.X, adult_test.Y.squeeze(), np.argmax(adult_test.A, axis=1)

elif args.data == "compas":
    from dataset import COMPAS
    compas = pd.read_csv("data/propublica.csv").values
    logger.debug("Shape of COMPAS dataset: {}".format(compas.shape))
    # Random shuffle and then partition by 70/30.
    num_classes = 2
    num_groups = 2
    num_insts = compas.shape[0]
    logger.info("Total number of instances in the COMPAS data: {}".format(num_insts))
    # Random shuffle the dataset (uses the NumPy global RNG seeded earlier).
    indices = np.arange(num_insts)
    np.random.shuffle(indices)
    compas = compas[indices]
    # Partition the dataset into train and test split.
    ratio = 0.7
    num_train = int(num_insts * ratio)
    compas_train = COMPAS(compas[:num_train, :])
    compas_test = COMPAS(compas[num_train:, :])
    X_train, Y_train, A_train = compas_train.insts, compas_train.labels.squeeze(), compas_train.attrs
    X_test, Y_test, A_test = compas_test.insts, compas_test.labels.squeeze(), compas_test.attrs

elif args.data == "crime":
    from dataset import CrimeDataset
    crime_train = CrimeDataset(root_dir='data', phase='train')
    crime_test = CrimeDataset(root_dir='data', phase='test')
    X_train, Y_train, A_train = crime_train.X, crime_train.Y.squeeze(), crime_train.A
    X_test, Y_test, A_test = crime_test.X, crime_test.Y.squeeze(), crime_test.A

elif args.data == "law":
    from dataset import LawSchoolDataset
    law_train = LawSchoolDataset(root_dir='data', phase='train')
    law_test = LawSchoolDataset(root_dir='data', phase='test')
    X_train, Y_train, A_train = law_train.X, law_train.Y.squeeze(), law_train.A
    X_test, Y_test, A_test = law_test.X, law_test.Y.squeeze(), law_test.A

else:
    # Same exception type as before, but now says which value was rejected.
    raise NotImplementedError(
        "Unknown dataset {!r}; expected one of: adult, compas, crime, law.".format(args.data)
    )

# train a normal regressor
from fairlearn.metrics import group_summary
from sklearn.metrics import mean_squared_error
from fairlearn.reductions import BoundedGroupLoss, SquareLoss
from sklearn.linear_model import SGDRegressor, LinearRegression
import torch

# Unconstrained baseline: fit a plain linear regression on the training split
# and report the test MSE summary grouped by the sensitive attribute.
regressor = LinearRegression().fit(X_train, Y_train)
Y_test_pred = regressor.predict(X_test)
baseline_summary = group_summary(
    mean_squared_error,
    Y_test,
    Y_test_pred,
    sensitive_features=A_test,
)
logger.info("For a normal regressor, MSE summary: {}".format(baseline_summary))

from fairlearn.reductions import ExponentiatedGradient

# Re-seed NumPy, `random` and PyTorch before fitting the mitigated model.
np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)

# Bounded Group Loss on the square loss (targets clipped to [0, 1] by the
# SquareLoss bounds), with the per-group loss bound taken from --ub.
bgl_constraint = BoundedGroupLoss(SquareLoss(min_val=0.0, max_val=1.0), upper_bound=args.ub)
regressor = LinearRegression()
mitigator = ExponentiatedGradient(regressor, bgl_constraint, eps=args.eps)
mitigator.fit(X_train, Y_train, sensitive_features=A_train)
Y_test_pred = mitigator.predict(X_test)

# Evaluate once and reuse the summary for both the log line and stdout;
# previously the identical group_summary call was made twice.
results = group_summary(mean_squared_error, Y_test, Y_test_pred, sensitive_features=A_test)
logger.info("Under Bounded Group Loss constraint, MSE summary: {}".format(results))
print(results['overall'], results['by_group'][0], results['by_group'][1])

# Pull the overall and per-group (groups 0 and 1) test MSE out of the summary.
cls_error = results['overall']
per_group = results['by_group']
error_0, error_1 = per_group[0], per_group[1]

# Normalise the overall MSE by the variance of the test targets; R^2 is
# one minus that normalised MSE.
ys_var = np.var(Y_test)
nmse = cls_error / ys_var
r_squared = 1 - nmse
logger.info("R squared = {}".format(r_squared))

# Collect the metrics that make up one row of the results CSV.
err_gap = np.abs(error_0 - error_1)
csv_data = {
    "cls_error": cls_error,
    "error_0": error_0,
    "error_1": error_1,
    "err_gap": err_gap,
    "R^2": r_squared,
    "nmse": nmse,
}

# Append this run's metrics as one row of <name>.csv, writing the header row
# only when the file is new (or empty).
csv_fn = args.name + ".csv"
fieldnames = ["cls_error", "error_0", "error_1", "err_gap", "R^2", "nmse"]

# This must be decided BEFORE open(): mode "a" creates the file, so an
# existence check made afterwards is always true and the header would never
# be written.
needs_header = (not os.path.exists(csv_fn)) or os.path.getsize(csv_fn) == 0

# newline="" lets the csv module control line endings itself, as its
# documentation recommends for files passed to csv writers.
with open(csv_fn, "a", newline="") as csv_file:
    writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
    if needs_header:
        writer.writeheader()
    writer.writerow(csv_data)


sys.exit(0)