import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import argparse
from utils import save_checkpoint, load_checkpoint
import os.path


def set_random_seed(seed):
    """Seed every random number generator this script relies on.

    Covers Python's ``random`` module, NumPy's global generator and PyTorch
    (the default generator, the current CUDA device and all CUDA devices).

    Args:
        seed: Seed value handed unchanged to each seeding function.
    """
    # Order matches the original sequence of calls.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


class Model(nn.Module):
    """Small fully connected regressor: ``dim -> 32 -> 64 -> 1`` with ReLU between layers."""

    def __init__(self, dim):
        """Build the network.

        Args:
            dim: Number of input features per sample.
        """
        super().__init__()
        self.dim = dim
        # Assemble Linear/ReLU pairs for the hidden widths, then the scalar
        # output layer. Layers are created in input-to-output order.
        layers = []
        in_features = self.dim
        for width in (32, 64):
            layers += [nn.Linear(in_features, width), nn.ReLU()]
            in_features = width
        layers.append(nn.Linear(in_features, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        return self.net(x)


def loss_func(y_pred, y0_pred, y1_pred, y, lambda_):
    """Compute the fairness-regularised training objective.

    Args:
        y_pred: Predictions for the observed inputs.
        y0_pred: Predictions for the first set of interventional inputs.
        y1_pred: Predictions for the second set of interventional inputs.
        y: Targets matching ``y_pred``.
        lambda_: Weight applied to the unfairness term.

    Returns:
        Tuple ``(prediction_loss, unfairness, loss)`` where ``prediction_loss``
        is the root of the mean squared error between ``y_pred`` and ``y``,
        ``unfairness`` is the absolute difference between the means of
        ``y0_pred`` and ``y1_pred``, and
        ``loss = prediction_loss + lambda_ * unfairness``.
    """
    rmse = F.mse_loss(y_pred, y).sqrt()
    mean_gap = (y0_pred.mean() - y1_pred.mean()).abs()
    total = rmse + lambda_ * mean_gap
    return rmse, mean_gap, total


def get_variable_index(d, protected, outcome, NonDes):
    """Build the column index lists for the three feature sets.

    Args:
        d: Total number of columns; indices ``0 .. d-1`` are considered.
        protected: Column index of the protected attribute.
        outcome: Column index of the outcome.
        NonDes: Iterable of column indices used for the "Fair" feature set.

    Returns:
        Tuple ``(index_full, index_unaware, index_IFair)``:
        every column except ``outcome``; every column except ``outcome`` and
        ``protected``; and the distinct entries of ``NonDes`` except
        ``outcome`` (in set iteration order).
    """
    index_full = []
    index_unaware = []
    for col in range(0, d):
        if col != outcome:
            index_full.append(col)
        if col != protected and col != outcome:
            index_unaware.append(col)
    fair_columns = set(NonDes).difference([outcome])
    index_IFair = list(fair_columns)
    return index_full, index_unaware, index_IFair


def data_preprocess(ground_truth=True, process="Train"):
    """Return the observational and interventional rows for one data split.

    Args:
        ground_truth: When true, the ``*_truth`` interventional datasets are
            used. When false, the non-truth interventional datasets are used
            for every split except "Test", which always uses the truth data.
        process: Key into ``process2index`` selecting the rows of the split.

    Returns:
        Tuple ``(data, interventional0_data, interventional1_data)`` holding
        the selected rows of the observational dataset and of the two chosen
        interventional datasets.
    """
    rows = process2index[process]
    use_truth = ground_truth == True or (process == "Test" and ground_truth == False)
    if use_truth:
        source0, source1 = interventional0_dataset_truth, interventional1_dataset_truth
    else:
        source0, source1 = interventional0_dataset, interventional1_dataset
    return (observational_data[rows,], source0[rows,], source1[rows,])


def train(model, optimizer, num_iters=100, eval_every=10, lambda_=0.0, mode="Full", ground_truth=True):
    """Fit ``model`` on the training split and checkpoint the best validation loss.

    Each iteration is one full-batch optimisation step on
    ``prediction_loss + lambda_ * unfairness`` (see ``loss_func``). Whenever
    ``i % eval_every == 0`` the same loss is computed on the validation split;
    if it is lower than the best value seen so far, the model is saved to
    ``{dir}/Model_{mode}.pt`` through ``save_checkpoint``.

    Args:
        model: Network mapping the selected feature columns to a prediction.
        optimizer: Optimizer stepping ``model``'s parameters.
        num_iters: Number of optimisation steps.
        eval_every: Validation interval, in iterations.
        lambda_: Weight of the unfairness term in the loss.
        mode: Feature set to train on. "Full" and "e-IFair" use
            ``index_full``, "Unaware" uses ``index_unaware`` and "Fair" uses
            ``index_IFair``.
        ground_truth: Forwarded to ``data_preprocess``; false selects the
            non-truth interventional data for training and validation.

    Returns:
        Tuple ``(train_loss_result, val_loss_result)``. Each element is a
        tuple of three lists of floats: total loss, prediction loss and
        unfairness (one entry per training step / per validation pass).

    Raises:
        ValueError: If ``mode`` is not one of the supported feature sets.
    """
    columns_by_mode = {
        "Full": index_full,
        "e-IFair": index_full,
        "Unaware": index_unaware,
        "Fair": index_IFair,
    }
    # Fail early with a clear message instead of an UnboundLocalError later on.
    if mode not in columns_by_mode:
        raise ValueError(
            "Unknown mode {!r}; expected one of {}".format(mode, sorted(columns_by_mode))
        )
    columns = columns_by_mode[mode]

    training_data, training_interventional0_data, training_interventional1_data = data_preprocess(
        ground_truth, process="Train")
    val_data, val_interventional0_data, val_interventional1_data = data_preprocess(
        ground_truth, process="Validation")

    def select_features(array):
        # Keep only the input columns used by this mode.
        return torch.from_numpy(array[:, columns])

    x = select_features(training_data)
    x0 = select_features(training_interventional0_data)
    x1 = select_features(training_interventional1_data)
    x_val = select_features(val_data)
    x0_val = select_features(val_interventional0_data)
    x1_val = select_features(val_interventional1_data)
    # Targets as column vectors so they match the model's (N, 1) output.
    y = torch.unsqueeze(torch.from_numpy(training_data[:, outcome]), 1)
    y_val = torch.unsqueeze(torch.from_numpy(val_data[:, outcome]), 1)

    train_loss_list = []
    train_prediction_loss_list = []
    train_unfairness_list = []
    val_loss_list = []
    val_prediction_loss_list = []
    val_unfairness_list = []

    # training loop
    best_val_loss = float("Inf")
    model.train()
    for i in range(num_iters):
        y_pred = model(x)
        y0_pred = model(x0)
        y1_pred = model(x1)
        train_prediction_loss, train_unfairness, train_loss = loss_func(y_pred, y0_pred, y1_pred, y, lambda_=lambda_)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        train_loss_list.append(train_loss.item())
        train_prediction_loss_list.append(train_prediction_loss.item())
        train_unfairness_list.append(train_unfairness.item())
        # evaluation step
        if i % eval_every == 0:
            model.eval()
            with torch.no_grad():
                y_val_pred = model(x_val)
                y0_val_pred = model(x0_val)
                y1_val_pred = model(x1_val)
                val_prediction_loss, val_unfairness, val_loss = loss_func(y_val_pred, y0_val_pred, y1_val_pred, y_val,
                                                                          lambda_=lambda_)
                val_loss_list.append(val_loss.item())
                val_prediction_loss_list.append(val_prediction_loss.item())
                val_unfairness_list.append(val_unfairness.item())
                # checkpoint whenever the validation loss improves
                if best_val_loss > val_loss:
                    best_val_loss = val_loss
                    save_checkpoint(f'{dir}/Model_{mode}.pt', model, best_val_loss)
            # Switch back to training mode; otherwise every step after the
            # first validation pass would run with the model in eval mode.
            model.train()
    train_loss_result = (train_loss_list, train_prediction_loss_list, train_unfairness_list)
    val_loss_result = (val_loss_list, val_prediction_loss_list, val_unfairness_list)
    return train_loss_result, val_loss_result


def evaluation(model, mode="Full", lambda_=0.0, ground_truth=True):
    """Evaluate ``model`` on the test split.

    Args:
        model: Trained network to evaluate; it is switched to eval mode.
        mode: Feature set fed to the model. "Full" and "e-IFair" use
            ``index_full``, "Unaware" uses ``index_unaware`` and "Fair" uses
            ``index_IFair``.
        lambda_: Weight of the unfairness term in the reported total loss.
        ground_truth: Forwarded to ``data_preprocess`` with ``process="Test"``.

    Returns:
        Tuple ``(test_loss, test_prediction_loss, test_unfairness)`` as
        returned by ``loss_func`` (note the total loss comes first).

    Raises:
        ValueError: If ``mode`` is not one of the supported feature sets.
    """
    columns_by_mode = {
        "Full": index_full,
        "e-IFair": index_full,
        "Unaware": index_unaware,
        "Fair": index_IFair,
    }
    # Fail early with a clear message instead of an UnboundLocalError later on.
    if mode not in columns_by_mode:
        raise ValueError(
            "Unknown mode {!r}; expected one of {}".format(mode, sorted(columns_by_mode))
        )
    columns = columns_by_mode[mode]

    test_data, test_interventional0_data, test_interventional1_data = data_preprocess(ground_truth, process="Test")
    x_test = torch.from_numpy(test_data[:, columns])
    x0_test = torch.from_numpy(test_interventional0_data[:, columns])
    x1_test = torch.from_numpy(test_interventional1_data[:, columns])
    # Targets as a column vector so they match the model's (N, 1) output.
    y_test = torch.unsqueeze(torch.from_numpy(test_data[:, outcome]), 1)
    model.eval()
    with torch.no_grad():
        y_test_pred = model(x_test)
        y0_test_pred = model(x0_test)
        y1_test_pred = model(x1_test)
        test_prediction_loss, test_unfairness, test_loss = loss_func(y_test_pred, y0_test_pred, y1_test_pred, y_test,
                                                                     lambda_=lambda_)
        test_loss_result = (test_loss, test_prediction_loss, test_unfairness)
    return test_loss_result


# ------Parameters setting------
# Seed all random number generators before anything random is drawn.
set_random_seed(532)
parser = argparse.ArgumentParser()
# Three required positional integers, in this order on the command line.
for arg_name, arg_help in (
        ('num_of_nodes', 'the number of nodes'),
        ('num_of_edges', 'the number of edges'),
        ('num_of_graphs', 'the number of graphs'),
):
    parser.add_argument(arg_name, type=int, help=arg_help)
args = parser.parse_args()
# d: number of nodes, s: number of edges, k: number of graphs
d, s, k = args.num_of_nodes, args.num_of_edges, args.num_of_graphs
print(f"{d} nodes {s} edges {k} graphs: ")

# data file path
# NOTE: `dir` shadows the builtin of the same name, but train() and the
# result-saving code read this module-level name, so it is kept as is.
dir = f"Repository/{d}nodes{s}edges"
observational_data_path = f"{dir}/observational_data_{k}.csv"
interventional0_data_truth_path = f"{dir}/interventional0_data_truth_{k}.csv"
interventional1_data_truth_path = f"{dir}/interventional1_data_truth_{k}.csv"
interventional0_data_gene_path = f"{dir}/interventional0_data_gene_{k}.csv"
interventional1_data_gene_path = f"{dir}/interventional1_data_gene_{k}.csv"
config_path = f"{dir}/config_{k}.txt"
relation_path = f"{dir}/relation_{k}.txt"

# All datasets are comma-separated numeric tables, loaded as float32 arrays.
observational_data = np.genfromtxt(observational_data_path, skip_header=0, delimiter=',').astype(np.float32)
interventional0_dataset = np.genfromtxt(interventional0_data_gene_path, skip_header=0, delimiter=',').astype(np.float32)
interventional1_dataset = np.genfromtxt(interventional1_data_gene_path, skip_header=0, delimiter=',').astype(np.float32)
interventional0_dataset_truth = np.genfromtxt(interventional0_data_truth_path, skip_header=0, delimiter=',').astype(np.float32)
interventional1_dataset_truth = np.genfromtxt(interventional1_data_truth_path, skip_header=0, delimiter=',').astype(np.float32)

# Import config: the first data row of the config file, as a {column: value} dict.
config = pd.read_csv(config_path, delimiter=',').T.to_dict()[0]
# Indices are shifted down by one (presumably stored 1-based in the file -- confirm).
protected = config['protected'] - 1
outcome = config['outcome'] - 1
n = config['sample_size']

# Import ancestral relations: one whitespace-separated index list per line,
# in the fixed order below. The file is closed as soon as it has been read.
with open(relation_path, 'r') as relation_file:
    Lines = relation_file.readlines()
relation = {
    name: [int(token) - 1 for token in Lines[row].strip().split()]
    for row, name in enumerate(('defNonDes', 'possDes', 'defDes', 'NonDes', 'Des'))
}

# Data splitting: hold out int(0.1 * n) rows for testing and the same number
# for validation; every remaining row is used for training. Both draws use
# NumPy's global generator, so the split depends on the seed set earlier and
# on the order of these calls.
test_index = np.random.choice(np.arange(0, n), size=int(0.1 * n), replace=False)
# Validation rows are sampled without replacement from the rows that are not
# already in the test split.
val_index = np.random.choice(np.array(list(set(np.arange(0, n)).difference(test_index))), size=int(0.1 * n),
                             replace=False)
# Rows in neither held-out split form the training set.
training_index = np.array(list(set(np.arange(0, n)).difference(np.concatenate((test_index, val_index), axis=0))))
# Lookup used by data_preprocess() to turn a process name into row indices.
process2index = {"Train": training_index, "Validation": val_index, "Test": test_index}
# Column index lists for the feature sets; the "Fair" set is built from relation['NonDes'].
index_full, index_unaware, index_IFair = get_variable_index(d, protected, outcome, relation['NonDes'])
# Baseline modes and the input columns each one uses.
mode_var = {'Full': index_full, 'Unaware': index_unaware, 'Fair': index_IFair}
####################################
#### Train and predict Baselines
####################################
# Baselines are trained without the unfairness penalty.
lambda_ = 0.0
# One row per run, one column per baseline mode (in mode_var order).
RMSE = np.zeros([1, len(mode_var)])
Unfairness = np.zeros([1, len(mode_var)])
for i, (mode, var_ind) in enumerate(mode_var.items()):
    if not var_ind:
        # No input columns for this mode: record NaN and skip it, rather than
        # building a zero-input model whose result would overwrite the NaN.
        RMSE[0, i] = np.nan
        Unfairness[0, i] = np.nan
        continue
    RMSE_cv = []
    Unfairness_cv = []
    # Train
    model = Model(dim=len(var_ind))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0)
    train(model, optimizer, num_iters=1000, eval_every=1, lambda_=lambda_, mode=mode)
    # Predict with the checkpoint that achieved the best validation loss.
    best_model = Model(dim=len(var_ind))
    load_checkpoint(f'{dir}/Model_{mode}.pt', best_model)
    test_loss, test_RMSE, test_unfairness = evaluation(best_model, mode=mode, lambda_=lambda_)
    RMSE_cv.append(test_RMSE.item())
    Unfairness_cv.append(test_unfairness.item())
    print("Finished Training {} with lambda={}!".format(mode, lambda_))
    RMSE[0, i] = round(sum(RMSE_cv) / len(RMSE_cv), 3)
    Unfairness[0, i] = round(sum(Unfairness_cv) / len(Unfairness_cv), 3)

# save data
# If either result file is missing, both are (re)created with a header row so
# the two files stay aligned; this run's row is then appended to each.
if (not os.path.exists("{}/RMSE.csv".format(dir))) or (not os.path.exists("{}/Unfairness.csv".format(dir))):
    pd_header = pd.DataFrame(columns=mode_var.keys())
    pd.DataFrame(pd_header).to_csv("{}/RMSE.csv".format(dir), index=False, header=True)
    pd.DataFrame(pd_header).to_csv("{}/Unfairness.csv".format(dir), index=False, header=True)
pd.DataFrame(RMSE).to_csv("{}/RMSE.csv".format(dir), mode='a', index=False, header=False)
pd.DataFrame(Unfairness).to_csv("{}/Unfairness.csv".format(dir), mode='a', index=False, header=False)

################################
### Train the model e-IFair (with ground_truth interventionals)
################################
# Unfairness weights to sweep; each value yields one output column.
lambdas = [0.0, 1, 3, 8, 10, 20, 40]
mode = 'e-IFair'
RMSEIF_truth = np.zeros([1, len(lambdas)])
UnfairnessIF_truth = np.zeros([1, len(lambdas)])
for col, lambda_ in enumerate(lambdas):
    # Fresh model and optimizer for every weight.
    model = Model(dim=len(index_full))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0)
    train_loss_full, val_loss_full = train(
        model, optimizer, num_iters=1000, eval_every=1, lambda_=lambda_, mode=mode)
    # Score the checkpoint with the best validation loss on the test split.
    best_model = Model(dim=len(index_full))
    load_checkpoint(f'{dir}/Model_{mode}.pt', best_model)
    test_loss, test_RMSE, test_unfairness = evaluation(best_model, mode=mode, lambda_=lambda_)
    RMSEIF_cv = [test_RMSE.item()]
    UnfairnessIF_cv = [test_unfairness.item()]
    print(f"Finished Training {mode} with lambda={lambda_}!")
    RMSEIF_truth[0, col] = round(sum(RMSEIF_cv) / len(RMSEIF_cv), 3)
    UnfairnessIF_truth[0, col] = round(sum(UnfairnessIF_cv) / len(UnfairnessIF_cv), 3)

# save data
rmse_truth_csv = f"{dir}/RMSEIF_truth.csv"
unfairness_truth_csv = f"{dir}/UnfairnessIF_truth.csv"
# If either file is missing, both are (re)created with a header row first.
if not (os.path.exists(rmse_truth_csv) and os.path.exists(unfairness_truth_csv)):
    pd_header = pd.DataFrame(columns=[f"{weight}IF" for weight in lambdas])
    pd_header.to_csv(rmse_truth_csv, index=False, header=True)
    pd_header.to_csv(unfairness_truth_csv, index=False, header=True)
# Append this run's results as one row per file.
pd.DataFrame(RMSEIF_truth).to_csv(rmse_truth_csv, mode='a', index=False, header=False)
pd.DataFrame(UnfairnessIF_truth).to_csv(unfairness_truth_csv, mode='a', index=False, header=False)

#################################
#### Train the model e-IFair (with generated interventionals)
#################################
# Unfairness weights to sweep; each value yields one output column.
lambdas = [0.0, 1, 3, 8, 10, 20, 40]
mode = 'e-IFair'
RMSEIF = np.zeros([1, len(lambdas)])
UnfairnessIF = np.zeros([1, len(lambdas)])
for col, lambda_ in enumerate(lambdas):
    # Fresh model and optimizer for every weight.
    model = Model(dim=len(index_full))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0)
    # ground_truth=False: train and validate on the generated interventional data.
    train_loss_full, val_loss_full = train(
        model, optimizer, num_iters=1000, eval_every=1, lambda_=lambda_, mode=mode, ground_truth=False)
    # Score the checkpoint with the best validation loss on the test split.
    best_model = Model(dim=len(index_full))
    load_checkpoint(f'{dir}/Model_{mode}.pt', best_model)
    test_loss, test_RMSE, test_unfairness = evaluation(
        best_model, mode=mode, lambda_=lambda_, ground_truth=False)
    RMSEIF_cv = [test_RMSE.item()]
    UnfairnessIF_cv = [test_unfairness.item()]
    print(f"Finished Training {mode} with lambda={lambda_}!")
    RMSEIF[0, col] = round(sum(RMSEIF_cv) / len(RMSEIF_cv), 3)
    UnfairnessIF[0, col] = round(sum(UnfairnessIF_cv) / len(UnfairnessIF_cv), 3)

# save data
rmse_gene_csv = f"{dir}/RMSEIF.csv"
unfairness_gene_csv = f"{dir}/UnfairnessIF.csv"
# If either file is missing, both are (re)created with a header row first.
if not (os.path.exists(rmse_gene_csv) and os.path.exists(unfairness_gene_csv)):
    pd_header = pd.DataFrame(columns=[f"{weight}IF" for weight in lambdas])
    pd_header.to_csv(rmse_gene_csv, index=False, header=True)
    pd_header.to_csv(unfairness_gene_csv, index=False, header=True)
# Append this run's results as one row per file.
pd.DataFrame(RMSEIF).to_csv(rmse_gene_csv, mode='a', index=False, header=False)
pd.DataFrame(UnfairnessIF).to_csv(unfairness_gene_csv, mode='a', index=False, header=False)
