#%%
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import argparse
from utils import save_checkpoint, load_checkpoint
import os.path
from torch.utils.tensorboard import SummaryWriter, writer
import mmd
import time


def set_random_seed(seed):
    """Seed every random number generator this script draws from.

    The same seed is handed, in order, to Python's ``random`` module, NumPy's
    global generator, PyTorch's default generator, and PyTorch's CUDA
    generators (current device first, then all devices).

    Args:
        seed: Seed value passed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_generator in seeders:
        seed_generator(seed)


class Model(nn.Module):
    """Fully connected regressor: ``dim`` inputs -> 64 -> 128 -> 1 output, ReLU in between."""

    def __init__(self, dim):
        """Build the network.

        Args:
            dim: Number of input features per sample.
        """
        super().__init__()
        self.dim = dim
        # Layers are instantiated in forward order, which fixes both the
        # parameter names (net.0, net.2, net.4) and the order in which the
        # initial weights are drawn from the random generator.
        layers = [
            nn.Linear(self.dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for input ``x`` (last dimension ``dim``)."""
        return self.net(x)


def loss_func(y_pred, y0_pred, y1_pred, y, lambda_):
    """Compute the prediction loss, the unfairness penalty and their weighted sum.

    Args:
        y_pred: Predictions to be compared against ``y``.
        y0_pred: Predictions forming the first sample of the MMD comparison.
        y1_pred: Predictions forming the second sample of the MMD comparison.
        y: Regression targets matching ``y_pred``.
        lambda_: Weight of the unfairness term; ``0`` disables it.

    Returns:
        Tuple ``(prediction_loss, unfairness, loss)``. ``prediction_loss`` is
        the MSE between ``y_pred`` and ``y``. When ``lambda_ == 0`` the
        unfairness is the plain int ``0`` (``y0_pred``/``y1_pred`` are not
        touched) and ``loss`` is the prediction loss itself. Otherwise
        ``unfairness = mmd.mix_rbf_mmd2(y0_pred, y1_pred)`` and
        ``loss = prediction_loss + lambda_ * unfairness``.
    """
    mse = F.mse_loss(y_pred, y)

    if lambda_ != 0:
        penalty = mmd.mix_rbf_mmd2(y0_pred, y1_pred)
        return mse, penalty, mse + lambda_ * penalty

    # No fairness regularisation requested: skip the MMD computation entirely.
    return mse, 0, mse


def get_variable_index(d, protected, outcome, NonDes):
    """Build the feature-column index lists used by the different models.

    Args:
        d: Total number of variables (columns ``0 .. d-1``).
        protected: Column index of the protected attribute.
        outcome: Column index of the outcome variable.
        NonDes: Iterable of column indices to keep for the "Fair" model.

    Returns:
        Tuple ``(index_full, index_unaware, index_IFair)``:
        every column except the outcome; every column except the outcome and
        the protected attribute; and the distinct entries of ``NonDes`` with
        the outcome removed (order follows set iteration).
    """
    index_full = []
    index_unaware = []
    for column in range(d):
        if column == outcome:
            continue
        index_full.append(column)
        if column != protected:
            index_unaware.append(column)
    index_IFair = list(set(NonDes).difference([outcome]))
    return index_full, index_unaware, index_IFair


def data_preprocess(ground_truth=True, process="Train"):
    """Select the observational and interventional rows belonging to one split.

    Reads the module-level ``process2index`` mapping and the module-level
    ``observational_data`` / ``interventional{0,1}_dataset[_truth]`` arrays.

    Args:
        ground_truth: When equal to ``True`` the ground-truth interventional
            datasets are used; otherwise the generated ones.
        process: Key of ``process2index`` naming the split. For ``"Test"`` the
            ground-truth interventional datasets are used regardless of
            ``ground_truth``.

    Returns:
        Tuple ``(data, interventional0_data, interventional1_data)`` holding
        the selected rows of each dataset.
    """
    observed_rows, intervene_rows = process2index[process]
    # Kept as an equality test (not plain truthiness) so that only values
    # equal to True select the ground-truth interventions outside the test split.
    use_truth = ground_truth == True or process == "Test"
    if use_truth:
        source0, source1 = interventional0_dataset_truth, interventional1_dataset_truth
    else:
        source0, source1 = interventional0_dataset, interventional1_dataset
    return observational_data[observed_rows], source0[intervene_rows], source1[intervene_rows]


def train(model, optimizer, num_iters=2000, eval_every=10, lambda_=0.0, mode="Full", ground_truth=True):
    """Train ``model`` full-batch, validating periodically and checkpointing the best state.

    Relies on module-level state: ``data_preprocess``, the column lists
    ``index_full`` / ``index_unaware`` / ``index_IFair``, the ``outcome``
    column index, the TensorBoard ``writer``, the output directory ``dir`` and
    ``save_checkpoint``.

    Args:
        model: Network to optimise; updated in place.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_iters: Number of full-batch gradient steps.
        eval_every: Validate (and log / checkpoint) every this many steps,
            starting with step 0.
        lambda_: Weight of the unfairness term passed to ``loss_func``.
        mode: Which input columns to use: ``"Full"`` and ``"e-IFair"`` use
            ``index_full``, ``"Unaware"`` uses ``index_unaware`` and ``"Fair"``
            uses ``index_IFair``. Also used in log tags and the checkpoint name.
        ground_truth: Forwarded to ``data_preprocess`` to choose between the
            ground-truth and generated interventional data.

    Raises:
        ValueError: If ``mode`` is not one of the four supported names.
    """
    columns_by_mode = {
        "Full": index_full,
        "e-IFair": index_full,
        "Unaware": index_unaware,
        "Fair": index_IFair,
    }
    if mode not in columns_by_mode:
        # Previously an unknown mode only failed later with an UnboundLocalError.
        raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(columns_by_mode)}")
    feature_cols = columns_by_mode[mode]

    training_data, training_interventional0_data, training_interventional1_data = data_preprocess(
        ground_truth, process="Train")
    val_data, val_interventional0_data, val_interventional1_data = data_preprocess(
        ground_truth, process="Validation")

    x = torch.from_numpy(training_data[:, feature_cols])
    x0 = torch.from_numpy(training_interventional0_data[:, feature_cols])
    x1 = torch.from_numpy(training_interventional1_data[:, feature_cols])
    x_val = torch.from_numpy(val_data[:, feature_cols])
    x0_val = torch.from_numpy(val_interventional0_data[:, feature_cols])
    x1_val = torch.from_numpy(val_interventional1_data[:, feature_cols])
    y = torch.unsqueeze(torch.from_numpy(training_data[:, outcome]), 1)
    y_val = torch.unsqueeze(torch.from_numpy(val_data[:, outcome]), 1)

    tag_prefix = mode + str(lambda_) + '/' + str(ground_truth)

    # training loop
    best_val_loss = float("Inf")
    model.train()
    for i in range(num_iters):
        y_pred = model(x)
        y0_pred = model(x0)
        y1_pred = model(x1)
        train_prediction_loss, train_unfairness, train_loss = loss_func(y_pred, y0_pred, y1_pred, y, lambda_=lambda_)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        if i % eval_every != 0:
            continue

        # evaluation step
        model.eval()
        with torch.no_grad():
            y_val_pred = model(x_val)
            y0_val_pred = model(x0_val)
            y1_val_pred = model(x1_val)
            val_prediction_loss, val_unfairness, val_loss = loss_func(y_val_pred, y0_val_pred, y1_val_pred, y_val,
                                                                      lambda_=lambda_)

            # loss_func returns a plain 0 (no .item()) for the unfairness when lambda_ == 0.
            train_stats = (train_loss.item(), train_prediction_loss.item(),
                           train_unfairness.item() if lambda_ != 0 else 0)
            val_stats = (val_loss.item(), val_prediction_loss.item(),
                         val_unfairness.item() if lambda_ != 0 else 0)

            # Record the training and validation losses of this step in the writer.
            for split_name, (total, prediction, unfair) in (("Train", train_stats), ("Validation", val_stats)):
                writer.add_scalar(tag_prefix + '/' + split_name + '/Loss', total, i)
                writer.add_scalar(tag_prefix + '/' + split_name + '/PredictionLoss', prediction, i)
                writer.add_scalar(tag_prefix + '/' + split_name + '/Unfairness', unfair, i)
                writer.flush()

            # print progress (the value labelled "rmse" is the MSE prediction loss)
            print(f"iter: {i}, trian ttloss: {round(train_stats[0],2)}, rmse={round(train_stats[1],2)}, unfair={round(train_stats[2],2)}")
            print(f"------validation ttloss: {round(val_stats[0],2)}, rmse={round(val_stats[1],2)}, unfair={round(val_stats[2],2)}")
            # checkpoint whenever the validation loss improves
            if best_val_loss > val_loss:
                best_val_loss = val_loss
                save_checkpoint(f'{dir}/Model_{mode}.pt', model, best_val_loss)

        # Switch back to training mode; without this every step after the first
        # validation would run with the model still in eval mode.
        model.train()
    return


def evaluation(model, mode="Full", lambda_=0.0, ground_truth=True):
    """Evaluate ``model`` on the test split.

    Relies on module-level state: ``data_preprocess``, the column lists
    ``index_full`` / ``index_unaware`` / ``index_IFair`` and the ``outcome``
    column index.

    Args:
        model: Trained network to evaluate (put into eval mode here).
        mode: Which input columns to use: ``"Full"`` and ``"e-IFair"`` use
            ``index_full``, ``"Unaware"`` uses ``index_unaware`` and ``"Fair"``
            uses ``index_IFair``.
        lambda_: Unfairness weight for the reported total loss. It is replaced
            by ``1e-16`` for the three baseline modes and for ``"e-IFair"`` with
            ``lambda_ == 0`` so that the unfairness is still computed.
        ground_truth: Forwarded to ``data_preprocess``; that function always
            uses the ground-truth interventional data for the ``"Test"`` split.

    Returns:
        Tuple ``((test_loss, test_rmse, test_unfairness), y0_test_pred, y1_test_pred)``.
        ``test_loss`` is MSE plus the weighted unfairness, ``test_rmse`` is the
        square root of the MSE, and the two prediction tensors are the model
        outputs on the two interventional test sets.

    Raises:
        ValueError: If ``mode`` is not one of the four supported names.
    """
    columns_by_mode = {
        "Full": index_full,
        "e-IFair": index_full,
        "Unaware": index_unaware,
        "Fair": index_IFair,
    }
    if mode not in columns_by_mode:
        # Previously an unknown mode only failed later with an UnboundLocalError.
        raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(columns_by_mode)}")
    feature_cols = columns_by_mode[mode]

    test_data, test_interventional0_data, test_interventional1_data = data_preprocess(ground_truth, process="Test")
    x_test = torch.from_numpy(test_data[:, feature_cols])
    x0_test = torch.from_numpy(test_interventional0_data[:, feature_cols])
    x1_test = torch.from_numpy(test_interventional1_data[:, feature_cols])
    y_test = torch.unsqueeze(torch.from_numpy(test_data[:, outcome]), 1)

    if mode in ("Full", "Unaware", "Fair") or (mode == 'e-IFair' and lambda_ == 0):
        # loss_func skips the MMD term when lambda_ == 0. A negligible non-zero
        # weight forces the unfairness to be measured while leaving the total
        # loss numerically equal to the prediction loss.
        lambda_ = 1e-16

    model.eval()
    with torch.no_grad():
        y_test_pred = model(x_test)
        y0_test_pred = model(x0_test)
        y1_test_pred = model(x1_test)
        test_prediction_loss, test_unfairness, test_loss = loss_func(y_test_pred, y0_test_pred, y1_test_pred, y_test,
                                                                     lambda_=lambda_)
        # Report RMSE rather than MSE (test_loss above was formed from the MSE).
        test_prediction_loss = torch.sqrt(test_prediction_loss)
        test_loss_result = (test_loss, test_prediction_loss, test_unfairness)
    return test_loss_result, y0_test_pred, y1_test_pred


#%%
# ------Parameters setting------
# Fix the global seed before anything random happens.
seed = 532
set_random_seed(seed)

# Four positional integer arguments select the simulated data to work on.
parser = argparse.ArgumentParser()
for cli_name, cli_help in (
    ('num_of_nodes', 'the number of nodes'),
    ('num_of_edges', 'the number of edges'),
    ('num_of_graphs', 'the number of graphs'),
    ('num_of_admissible_vars', 'the number of admissible variables'),
):
    parser.add_argument(cli_name, type=int, help=cli_help)
args = parser.parse_args()

d = args.num_of_nodes  # number of nodes
s = args.num_of_edges  # number of edges
k = args.num_of_graphs  # number of graphs (also used as the suffix of the data file names)
adm = args.num_of_admissible_vars  # number of admissible variables
print(f"{d} nodes {s} edges {k} graphs {adm} admissible variables: ")

# data file path
# NOTE: `dir` shadows the builtin, but other parts of this script read this
# global by that name, so it is kept.
dir = f"Repository_adm={adm}/{d}nodes{s}edges"
writer = SummaryWriter(f"{dir}/log_{k}.log")
observational_data_path = f"{dir}/observational_data_{k}.csv"
counterfactual_data_path = f"{dir}/counterfactual_data_{k}.csv"
interventional0_data_truth_path = f"{dir}/interventional0_data_truth_{k}.csv"
interventional1_data_truth_path = f"{dir}/interventional1_data_truth_{k}.csv"
interventional0_data_gene_path = f"{dir}/interventional0_data_gene_{k}.csv"
interventional1_data_gene_path = f"{dir}/interventional1_data_gene_{k}.csv"
config_path = f"{dir}/config_{k}.txt"
relation_path = f"{dir}/relation_{k}.txt"


def _load_float32_csv(path):
    """Read a comma-separated numeric file (no header skipped) as a float32 array."""
    return np.genfromtxt(path, skip_header=0, delimiter=',').astype(np.float32)


# Observational / counterfactual samples, then the generated ("gene") and
# ground-truth ("truth") interventional samples for both intervention values.
observational_data = _load_float32_csv(observational_data_path)
counterfactual_data = _load_float32_csv(counterfactual_data_path)
interventional0_dataset = _load_float32_csv(interventional0_data_gene_path)
interventional1_dataset = _load_float32_csv(interventional1_data_gene_path)
interventional0_dataset_truth = _load_float32_csv(interventional0_data_truth_path)
interventional1_dataset_truth = _load_float32_csv(interventional1_data_truth_path)

# Import config: the row labelled 0 of the config file as a {column name: value} dict.
config = pd.read_csv(config_path, delimiter=',').T.to_dict()[0]
# Stored indices are shifted down by one (presumably 1-based in the file -- confirm
# against the data generator) to obtain 0-based column indices.
protected = config['protected'] - 1
outcome = config['outcome'] - 1
n = len(observational_data)  # size of the observational data (instead of config['sample_size'])
# size of the ground-truth interventional data (instead of config['interventional_size'])
intervene_n = len(interventional0_dataset_truth)

# Import ancestral relations: lines 1-5 of the relation file each hold a
# whitespace-separated list of variable numbers, shifted to 0-based here.
# The file is read inside a context manager so the handle is closed promptly.
with open(relation_path, 'r') as relation_file:
    Lines = relation_file.readlines()
relation = {
    relation_name: [int(token) - 1 for token in Lines[row].strip().split()]
    for row, relation_name in enumerate(('defNonDes', 'possDes', 'defDes', 'NonDes', 'Des'))
}


#%%
# Data splitting
# Observational data: 10% test, 10% validation, the rest for training.
# np.setdiff1d returns the remaining indices as a sorted integer array, so the
# split does not depend on set iteration order and stays an integer array even
# when nothing is left.
all_index = np.arange(0, n)
test_index = np.random.choice(all_index, size=int(0.1 * n), replace=False)
val_index = np.random.choice(np.setdiff1d(all_index, test_index), size=int(0.1 * n), replace=False)
training_index = np.setdiff1d(all_index, np.concatenate((test_index, val_index), axis=0))

# Interventional data: the second half is the test set; the first half is
# split into validation and training.
intervene_first_half = np.arange(0, intervene_n//2)
intervene_test_index = np.arange(0, intervene_n)[intervene_n//2:]
# Note: `0.2 * intervene_n//2` is evaluated as `(0.2 * intervene_n) // 2`.
intervene_val_index = np.random.choice(intervene_first_half, size=int(0.2 * intervene_n//2), replace=False)
intervene_training_index = np.setdiff1d(intervene_first_half, intervene_val_index)


# Split name -> (observational row indices, interventional row indices).
process2index = {"Train": (training_index, intervene_training_index), "Validation": (val_index, intervene_val_index),
                 "Test": (test_index, intervene_test_index)}
index_full, index_unaware, index_IFair = get_variable_index(d, protected, outcome, relation['NonDes'])
# Baseline name -> input columns used by that baseline.
mode_var = {'Full': index_full, 'Unaware': index_unaware, 'Fair': index_IFair}


#%%
# Training budgets in iterations: `Iter` for lambda < 10, `Iter_large` when lambda >= 10.
Iter, Iter_large = 400, 500

#%%
####################################
#### Train and predict Baselines
####################################
lambda_ = 0.0
num_iters = Iter
RMSE = np.zeros([1, len(mode_var)])
Unfairness = np.zeros([1, len(mode_var)])
y_test_pred_dir = "y_pred"
os.makedirs(f"{dir}/{y_test_pred_dir}", exist_ok=True)
for i, mode in enumerate(mode_var):
    # NOTE(review): 'Full' is the first key of mode_var, so num_iters switches
    # to Iter_large on the first pass and is never reset; every baseline
    # therefore trains for Iter_large iterations. Kept as-is -- confirm whether
    # the other baselines were meant to use Iter.
    if mode == 'Full':
        num_iters = Iter_large
    set_random_seed(seed)
    var_ind = mode_var[mode]
    if not var_ind:
        # No input variables for this baseline: record NaN and skip it.
        # (Without the `continue` the NaNs were overwritten by the results of
        # a model trained on zero features.)
        RMSE[0, i] = np.nan
        Unfairness[0, i] = np.nan
        print(f"----------- Skipping {mode}: no input variables -----------")
        continue
    # Train
    print(f"\n----------- Start Training {mode} with lambda={lambda_} -----------")
    model = Model(dim=len(var_ind))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0)
    train(model, optimizer, num_iters=num_iters, lambda_=lambda_, mode=mode)
    # Predict with the best checkpoint written during training
    best_model = Model(dim=len(var_ind))
    load_checkpoint(f'{dir}/Model_{mode}.pt', best_model)
    test_loss_result, y0_test_pred, y1_test_pred = evaluation(best_model, mode=mode, lambda_=lambda_)
    test_loss, test_RMSE, test_unfairness = test_loss_result
    print(f"----------- Finished Training {mode} with lambda={lambda_}! -----------\n")
    RMSE[0, i] = round(test_RMSE.item(), 3)
    Unfairness[0, i] = round(test_unfairness.item(), 3)

    # save predicted y: one appended row per run for each intervention value
    for arm_name, arm_pred in (("y0", y0_test_pred), ("y1", y1_test_pred)):
        pd.DataFrame(arm_pred.view(1, -1).numpy()).to_csv(
            f"{dir}/{y_test_pred_dir}/{mode}_{arm_name}.csv", mode='a', index=False, header=False)

print(RMSE)
print(Unfairness)

## save data
# Each results file gets its header only when that file does not exist yet, so
# an existing file is never truncated because its sibling is missing.
for metric_name, metric_row in (("RMSE", RMSE), ("Unfairness", Unfairness)):
    metric_path = f"{dir}/{metric_name}.csv"
    if not os.path.exists(metric_path):
        pd.DataFrame(columns=list(mode_var)).to_csv(metric_path, index=False, header=True)
    pd.DataFrame(metric_row).to_csv(metric_path, mode='a', index=False, header=False)

#%%
##############################
# Train the model e-IFair (with ground_truth interventionals)
##############################
lambdas = [0, 0.5, 5, 20, 60, 100]
mode = 'e-IFair'
RMSEIF_truth = np.zeros([1, len(lambdas)])
UnfairnessIF_truth = np.zeros([1, len(lambdas)])
y_test_pred_dir = "y_pred"
os.makedirs(f"{dir}/{y_test_pred_dir}", exist_ok=True)
for i, lambda_ in enumerate(lambdas):
    num_iters = Iter_large if lambda_ > 5 else Iter
    set_random_seed(seed)
    model = Model(dim=len(index_full))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0)
    train(model, optimizer, num_iters=num_iters, lambda_=lambda_, mode=mode)
    # The checkpoint path depends only on `mode`, so it is overwritten for each
    # lambda and must be loaded right after the corresponding training run.
    best_model = Model(dim=len(index_full))
    load_checkpoint(f'{dir}/Model_{mode}.pt', best_model)
    test_loss_result, y0_test_pred, y1_test_pred = evaluation(best_model, mode=mode, lambda_=lambda_)
    test_loss, test_RMSE, test_unfairness = test_loss_result
    print(f"Finished Training {mode} with lambda={lambda_}!")
    RMSEIF_truth[0, i] = round(test_RMSE.item(), 3)
    UnfairnessIF_truth[0, i] = round(test_unfairness.item(), 3)

    # save predicted y: one appended row per run for each intervention value
    for arm_name, arm_pred in (("y0", y0_test_pred), ("y1", y1_test_pred)):
        pd.DataFrame(arm_pred.view(1, -1).numpy()).to_csv(
            f"{dir}/{y_test_pred_dir}/{mode}_{lambda_}_{arm_name}_real.csv", mode='a', index=False, header=False)

print(RMSEIF_truth)
print(UnfairnessIF_truth)
# save data
# Each results file gets its header only when that file does not exist yet, so
# an existing file is never truncated because its sibling is missing.
for metric_name, metric_row in (("RMSEIF_truth", RMSEIF_truth), ("UnfairnessIF_truth", UnfairnessIF_truth)):
    metric_path = f"{dir}/{metric_name}.csv"
    if not os.path.exists(metric_path):
        pd.DataFrame(columns=[str(lam) + "IF" for lam in lambdas]).to_csv(metric_path, index=False, header=True)
    pd.DataFrame(metric_row).to_csv(metric_path, mode='a', index=False, header=False)

#%%
# ################################
# ### Train the model e-IFair (with generated interventionals)
# ################################
lambdas = [0, 0.5, 5, 20, 60, 100]
mode = 'e-IFair'
RMSEIF = np.zeros([1, len(lambdas)])
UnfairnessIF = np.zeros([1, len(lambdas)])
y_test_pred_dir = "y_pred"
os.makedirs(f"{dir}/{y_test_pred_dir}", exist_ok=True)
for i, lambda_ in enumerate(lambdas):
    set_random_seed(seed)
    model = Model(dim=len(index_full))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0)

    num_iters = Iter_large if lambda_ > 5 else Iter
    # ground_truth=False: train and validate on the generated interventional data.
    train(model, optimizer, num_iters=num_iters, lambda_=lambda_, mode=mode, ground_truth=False)
    # The checkpoint path depends only on `mode`, so it is overwritten for each
    # lambda and must be loaded right after the corresponding training run.
    best_model = Model(dim=len(index_full))
    load_checkpoint(f'{dir}/Model_{mode}.pt', best_model)
    test_loss_result, y0_test_pred, y1_test_pred = evaluation(best_model, mode=mode, lambda_=lambda_,
                                                              ground_truth=False)
    test_loss, test_RMSE, test_unfairness = test_loss_result
    print(f"Finished Training {mode} with lambda={lambda_}!")
    RMSEIF[0, i] = round(test_RMSE.item(), 3)
    UnfairnessIF[0, i] = round(test_unfairness.item(), 3)

    # save predicted y: one appended row per run for each intervention value
    for arm_name, arm_pred in (("y0", y0_test_pred), ("y1", y1_test_pred)):
        pd.DataFrame(arm_pred.view(1, -1).numpy()).to_csv(
            f"{dir}/{y_test_pred_dir}/{mode}_{lambda_}_{arm_name}_gene.csv", mode='a', index=False, header=False)

print(RMSEIF)
print(UnfairnessIF)
# save data
# Each results file gets its header only when that file does not exist yet, so
# an existing file is never truncated because its sibling is missing.
for metric_name, metric_row in (("RMSEIF", RMSEIF), ("UnfairnessIF", UnfairnessIF)):
    metric_path = f"{dir}/{metric_name}.csv"
    if not os.path.exists(metric_path):
        pd.DataFrame(columns=[str(lam) + "IF" for lam in lambdas]).to_csv(metric_path, index=False, header=True)
    pd.DataFrame(metric_row).to_csv(metric_path, mode='a', index=False, header=False)

# %%
import time
import datetime

# Report when the run reached this point: first the raw epoch seconds,
# then a human-readable local timestamp.
finished_epoch_seconds = time.time()
print(finished_epoch_seconds)

finished_at = datetime.datetime.now()
print(finished_at.strftime('%Y-%m-%d %H:%M:%S'))


