import torch
import argparse
import os
import matplotlib.pyplot as plt
import numpy as np
import itertools
import math
import random
import shap
import copy
from tqdm import tqdm
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.stats import pearsonr, spearmanr
from scipy.stats import norm

import torch.optim as optim
from train_utils import train_mlp, test_mlp, train_sklearn, test_sklearn, get_savefldr, get_loadfile
from dataloader import get_folktables_dataset
from models import get_model, flatten_weights
from helper import init_to_name_dict, init_to_color_dict, id_to_marker_dict, id_to_arch_dict
from meta_fairness import equal_opp_binary, fair_loss_binary

from fairtorch_local import DemographicParityLoss, EqualiedOddsLoss

import warnings
warnings.filterwarnings("ignore")

# ---------------------------------------------------------------------------
# Command-line interface.
# Trains a single model on a Folktables dataset via train_mlp and saves it
# under the folder returned by get_savefldr.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Train a model on a Folktables dataset.")
# Restrict --arch to known ids: the value is used as a key into id_to_arch_dict
# below, so an unknown id would otherwise fail later with a bare KeyError.
parser.add_argument("--arch", default="mlp_64", choices=id_to_arch_dict, help="Model Architecture")
parser.add_argument("--loss", default="ce", help="Loss Type for Training/Finetuning")
parser.add_argument("--type", default="acsincome", help="Dataset Type")
parser.add_argument("--pc", default="sex", help="Protected Class")
parser.add_argument("--year", type=int, default=2018, help="Dataset Year")
parser.add_argument("--state", default="CA", help="Dataset State")
parser.add_argument("--ckptfldr", default="folktables2018CA", help="Folder for Saving Model Files")
# NOTE(review): --ckptloss is parsed but not read anywhere in this script; confirm or remove.
parser.add_argument("--ckptloss", default=None, help="Loss Type for Training")
parser.add_argument("--shuffle_seed", type=int, default=0, help="Seed for mini-batch shuffling")
parser.add_argument("--init_seed", type=int, default=0, help="Seed for weight initialization")
# Previously hard-coded; the default keeps existing runs (and their save folders) unchanged.
parser.add_argument("--we_init", default="kaiming_normal", help="Weight initialization scheme passed to get_model")
parser.add_argument("--num_feat", type=int, default=10, help="Number of Input Features in Task")
parser.add_argument("--epochs", type=int, default=150, help="Number of Training Epochs")
parser.add_argument("--lr", type=float, default=0.001, help="Learning Rate for Training")
# Float default so args.lmbd is a float whether or not the flag is given
# (argparse does not apply `type` to non-string defaults).
parser.add_argument("--lmbd", type=float, default=1.0, help="Combinatory Factor for Fair Loss")
parser.add_argument("--gpus", default="0,1", help="GPU Device ID to use")
parser.add_argument("--cuda", action="store_true", help="Use CUDA")
# NOTE(review): the two --preprocess_* flags below are parsed but not read anywhere in
# this script, so they currently have no effect; confirm whether they should be wired up.
parser.add_argument("--preprocess_pertubate", action="store_true", help="Use pertubation data preprocessing. Hard coded to only work for class Sex")
parser.add_argument("--preprocess_upsample", action="store_true", help="Use upsampling data preprocessing. Hard coded to only work for class Sex")
parser.add_argument("--fairbatch", action="store_true", help="Use FairBatch")
parser.add_argument("--reweigh", action="store_true", help="Use Reweighing")
parser.add_argument("--gradnoise", action="store_true", help="Add Noise to Gradient")
parser.add_argument("--no_detect_anomaly", action="store_true",
                    help="Disable autograd anomaly detection (enabled by default; it slows training)")

args = parser.parse_args()

# Anomaly detection makes autograd point at the forward op behind a failing
# backward step, but it adds overhead to every backward pass. It stays on by
# default to preserve previous behaviour; pass --no_detect_anomaly for speed.
torch.autograd.set_detect_anomaly(not args.no_detect_anomaly)

# CUDA reads this variable when it is first initialised in the process; torch
# initialises CUDA lazily, so setting it after `import torch` still works.
# NOTE(review): assumes none of the project modules imported above touch CUDA at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

we_init = args.we_init
arch_id = args.arch
init_seed = args.init_seed
shuffle_seed = args.shuffle_seed

# Seed torch's global RNG right before building the model; presumably get_model
# draws its initial weights from it, making initialisation reproducible per init_seed.
torch.manual_seed(init_seed)
model = get_model(arch_id, args.num_feat, id_to_arch_dict[arch_id], we_init, ckpt=None, cuda=args.cuda)
savefldr = get_savefldr(args.ckptfldr, args.pc, args.loss, we_init, f"{arch_id}_{init_seed}", shuffle_seed)

# 70% train / 10% validation / 20% test split for the chosen state and year.
trainloader, validloader, testloader = get_folktables_dataset(
    train_valid_test_split=[0.7, 0.1, 0.2],
    protected_class=args.pc,
    survey_year=args.year,
    states=[args.state],
    shuffle_seed=shuffle_seed,
    fairbatch=args.fairbatch,
    model=model,
    dataset_type=args.type,
)

train_mlp(
    model,
    trainloader,
    savefldr,
    validloader=validloader,
    cuda=args.cuda,
    losstype=args.loss,
    epochs=args.epochs,
    lr=args.lr,
    lmbd=args.lmbd,
    reweigh=args.reweigh,
    gradnoise=args.gradnoise,
)
