#%%
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import argparse
from utils import save_checkpoint, load_checkpoint
import os.path


def set_random_seed(seed):
    '''Seed Python's `random`, NumPy and torch (CPU and CUDA) with the same value.'''
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same order as listing the calls one by one.
    for apply_seed in seeders:
        apply_seed(seed)


class Model(nn.Module):
    '''Small fully connected regressor: dim -> 32 -> 64 -> 1 with ReLU between layers.'''

    def __init__(self, dim):
        '''Build the network for inputs that have `dim` features.'''
        super().__init__()
        self.dim = dim
        first_width, second_width = 32, 64
        # Layers are created in input-to-output order and stored under `net`,
        # so parameter names stay `net.0.*`, `net.2.*`, `net.4.*`.
        stack = [
            nn.Linear(dim, first_width),
            nn.ReLU(),
            nn.Linear(first_width, second_width),
            nn.ReLU(),
            nn.Linear(second_width, 1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        '''Return the network output for `x` (last dimension becomes 1).'''
        return self.net(x)


def loss_func(y_pred, y0_pred, y1_pred, y, lambda_):
    '''Return (prediction_loss, unfairness, loss).

    prediction_loss is the mean squared error between y_pred and y;
    unfairness is |mean(y0_pred) - mean(y1_pred)|;
    loss = prediction_loss + lambda_ * unfairness.
    '''
    prediction_loss = F.mse_loss(y_pred, y)
    mean_gap = y0_pred.mean() - y1_pred.mean()
    unfairness = mean_gap.abs()
    total = prediction_loss + lambda_ * unfairness
    return prediction_loss, unfairness, total


def get_variable_index(nodes, protected, outcome, NonDes):
    '''Build the feature-name lists used by the different models.

    protected, outcome are strings; NonDes is returned unchanged.
    nodes: [node1, node2, node3, ...]

    Returns (index_full, index_unaware, index_IFair):
      index_full    - every node except the outcome
      index_unaware - every node except the protected attribute and the outcome
      index_IFair   - NonDes itself
    '''
    index_full = []
    for node in nodes:
        if node != outcome:
            index_full.append(node)

    index_unaware = []
    for node in nodes:
        if node != protected and node != outcome:
            index_unaware.append(node)

    return index_full, index_unaware, NonDes


def data_preprocess(process2index, ground_truth=True, process="Train"):
    '''Select the rows for one phase ("Train", "Validation" or "Test").

    process2index[process] is a sequence of row-position collections:
    entries 0 and 1 index the module-level interventional0_dataset /
    interventional1_dataset, entries 2 and 3 index the *_real datasets.

    Returns (data, interventional0_data, interventional1_data) where `data`
    is the two real selections stacked with a fresh 0..n-1 index. The
    interventional frames are the real selections when ground_truth is true,
    or when process is "Test" and ground_truth is false; otherwise they come
    from the non-real datasets.
    '''
    rows = process2index[process]

    real0 = interventional0_dataset_real.iloc[rows[2],]
    real1 = interventional1_dataset_real.iloc[rows[3],]
    stacked = pd.concat([real0, real1], axis=0)
    data = stacked.reset_index(drop=True)

    use_real = ground_truth == True or (process == "Test" and ground_truth == False)
    if use_real:
        return (data, real0, real1)

    generated0 = interventional0_dataset.iloc[rows[0],]
    generated1 = interventional1_dataset.iloc[rows[1],]
    return (data, generated0, generated1)

# Compute device shared by the script: CUDA when torch reports it as available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def train(model, optimizer, process2index, num_iters=100, eval_every=10, lambda_=0.0, mode="Full", ground_truth=True):
    '''Fit `model` with full-batch steps on prediction loss + lambda_ * unfairness.

    Relies on module-level names: `device`, `index_full`, `index_unaware`,
    `index_IFair`, `outcome`, `dir`, plus `data_preprocess`, `loss_func` and
    `save_checkpoint`.

    Args:
        model: network to train; it is moved to `device` in place.
        optimizer: optimizer already bound to `model`'s parameters.
        process2index: mapping with "Train" and "Validation" entries, passed
            through to `data_preprocess`.
        num_iters: number of optimisation steps (one per pass over the data).
        eval_every: validation runs on iterations where i % eval_every == 0.
        lambda_: weight of the unfairness term.
        mode: "Full" or "e-IFair" (columns in index_full), "Unaware"
            (index_unaware) or "Fair" (index_IFair).
        ground_truth: forwarded to `data_preprocess`.

    Returns:
        (train_loss_result, val_loss_result). Each is a tuple of three lists of
        floats: total loss, prediction loss, unfairness. Training lists get one
        entry per iteration, validation lists one entry per evaluation.

    Raises:
        ValueError: if `mode` is not one of the supported names.

    Side effects:
        Whenever the validation loss improves, `save_checkpoint` is called with
        the path f'{dir}/Model_{mode}.pt'.
    '''
    columns_by_mode = {
        "Full": index_full,
        "e-IFair": index_full,
        "Unaware": index_unaware,
        "Fair": index_IFair,
    }
    if mode not in columns_by_mode:
        # Previously an unknown mode surfaced later as an UnboundLocalError on `x`.
        raise ValueError("Unknown mode {!r}; expected one of {}".format(mode, sorted(columns_by_mode)))
    columns = columns_by_mode[mode]

    training_data, training_interventional0_data, training_interventional1_data = data_preprocess(
        process2index, ground_truth, process="Train")
    val_data, val_interventional0_data, val_interventional1_data = data_preprocess(
        process2index, ground_truth, process="Validation")

    model.to(device)

    def as_input(frame):
        # Feature matrix for the selected columns, placed on the training device.
        return torch.from_numpy(frame[columns].values).to(device)

    def as_target(frame):
        # Outcome as a column vector; it must live on the same device as the
        # predictions or the loss computation fails when CUDA is selected.
        return torch.unsqueeze(torch.from_numpy(frame[outcome].values), 1).to(device)

    # All tensors are moved to the device once, outside the training loop.
    x = as_input(training_data)
    x0 = as_input(training_interventional0_data)
    x1 = as_input(training_interventional1_data)
    y = as_target(training_data)
    x_val = as_input(val_data)
    x0_val = as_input(val_interventional0_data)
    x1_val = as_input(val_interventional1_data)
    y_val = as_target(val_data)

    train_loss_list = []
    train_prediction_loss_list = []
    train_unfairness_list = []
    val_loss_list = []
    val_prediction_loss_list = []
    val_unfairness_list = []

    # training loop
    best_val_loss = float("Inf")
    model.train()

    for i in range(num_iters):
        y_pred = model(x)
        y0_pred = model(x0)
        y1_pred = model(x1)
        train_prediction_loss, train_unfairness, train_loss = loss_func(
            y_pred, y0_pred, y1_pred, y, lambda_=lambda_)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        train_loss_list.append(train_loss.item())
        train_prediction_loss_list.append(train_prediction_loss.item())
        train_unfairness_list.append(train_unfairness.item())

        # evaluation step
        if i % eval_every == 0:
            model.eval()
            with torch.no_grad():
                y_val_pred = model(x_val)
                y0_val_pred = model(x0_val)
                y1_val_pred = model(x1_val)
                val_prediction_loss, val_unfairness, val_loss = loss_func(
                    y_val_pred, y0_val_pred, y1_val_pred, y_val, lambda_=lambda_)
            val_loss_list.append(val_loss.item())
            val_prediction_loss_list.append(val_prediction_loss.item())
            val_unfairness_list.append(val_unfairness.item())

            # checkpoint: keep the parameters with the lowest validation loss so far
            current_val_loss = val_loss.item()
            if current_val_loss < best_val_loss:
                best_val_loss = current_val_loss
                # The loss tensor itself is handed over, as before.
                save_checkpoint(f'{dir}/Model_{mode}.pt', model, val_loss)

            # Switch back so the following optimisation steps run in training mode.
            model.train()

    train_loss_result = (train_loss_list, train_prediction_loss_list, train_unfairness_list)
    val_loss_result = (val_loss_list, val_prediction_loss_list, val_unfairness_list)
    return train_loss_result, val_loss_result


def evaluation(model, process2index, mode="Full", lambda_=0.0, ground_truth=True):
    '''Score `model` on the "Test" split.

    Relies on module-level names: `index_full`, `index_unaware`, `index_IFair`,
    `outcome`, plus `data_preprocess` and `loss_func`.

    Args:
        model: trained network; it is switched to eval mode.
        process2index: mapping with a "Test" entry, passed to `data_preprocess`.
        mode: "Full" or "e-IFair" (columns in index_full), "Unaware"
            (index_unaware) or "Fair" (index_IFair).
        lambda_: weight of the unfairness term in the reported total loss.
        ground_truth: forwarded to `data_preprocess`.

    Returns:
        (test_loss, test_rmse, test_unfairness) as scalar tensors, where
        test_rmse is the square root of the prediction (MSE) loss.

    Raises:
        ValueError: if `mode` is not one of the supported names.
    '''
    columns_by_mode = {
        "Full": index_full,
        "e-IFair": index_full,
        "Unaware": index_unaware,
        "Fair": index_IFair,
    }
    if mode not in columns_by_mode:
        # Previously an unknown mode surfaced as an UnboundLocalError on `x_test`.
        raise ValueError("Unknown mode {!r}; expected one of {}".format(mode, sorted(columns_by_mode)))
    columns = columns_by_mode[mode]

    test_data, test_interventional0_data, test_interventional1_data = data_preprocess(
        process2index, ground_truth, process="Test")

    # Put the inputs wherever the model's parameters are, so a model that sits
    # on the GPU can be evaluated as well as a CPU one.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    def as_input(frame):
        return torch.from_numpy(frame[columns].values).to(target_device)

    x_test = as_input(test_data)
    x0_test = as_input(test_interventional0_data)
    x1_test = as_input(test_interventional1_data)
    y_test = torch.unsqueeze(torch.from_numpy(test_data[outcome].values), 1).to(target_device)

    model.eval()
    with torch.no_grad():
        y_test_pred = model(x_test)
        y0_test_pred = model(x0_test)
        y1_test_pred = model(x1_test)

        test_prediction_loss, test_unfairness, test_loss = loss_func(
            y_test_pred, y0_test_pred, y1_test_pred, y_test, lambda_=lambda_)

        test_loss_result = (test_loss, torch.sqrt(test_prediction_loss), test_unfairness)
    return test_loss_result


#%%
from sklearn.model_selection import KFold

# # ------Parameters setting------
# Seed every RNG before any model or split is created.
seed = 42
set_random_seed(seed)

# Values passed as num_iters to train() for the baseline runs and the e-IFair runs.
BASE_ITER = 1000
GENE_ITER = 1000

path = "PATH_TO_DATA"
# NOTE: `dir` shadows the builtin of the same name; the name is kept because
# train() reads this module-level variable when it builds the checkpoint path.
dir = path+"/data"
observational_data_path = f"{dir}/observation_data.csv"
interventional0_data_gene_path = f"{dir}/interventional0_data_gene.csv"
interventional1_data_gene_path = f"{dir}/interventional1_data_gene.csv"
interventional0_data_real_path = f"{dir}/interventional0_data_real.csv"
interventional1_data_real_path = f"{dir}/interventional1_data_real.csv"


def _load_float32_csv(csv_path):
    '''Read a comma-separated file with a header row and cast all columns to float32.'''
    return pd.read_csv(csv_path, header=0, delimiter=',').astype(np.float32)


observational_data = _load_float32_csv(observational_data_path)
interventional0_dataset = _load_float32_csv(interventional0_data_gene_path)
interventional1_dataset = _load_float32_csv(interventional1_data_gene_path)
interventional0_dataset_real = _load_float32_csv(interventional0_data_real_path)
interventional1_dataset_real = _load_float32_csv(interventional1_data_real_path)

protected = 'sex'
outcome = 'Grade'
Des = ['Walc', 'Dalc', 'studytime', 'goout']
# Keep the CSV column order instead of going through list(set(...)): set
# iteration order for strings can differ between interpreter runs, which would
# reorder the "Fair" model's inputs and make results irreproducible.
# NOTE(review): 'failure' is excluded by its literal name - confirm it matches
# the column spelling in the CSV.
_excluded_columns = {protected, outcome, 'failure', *Des}
NonDes = [column for column in observational_data.columns if column not in _excluded_columns]

n = np.array([len(interventional0_dataset), len(interventional1_dataset), len(observational_data)])
nodes = interventional0_dataset.columns

index_full, index_unaware, index_IFair = get_variable_index(nodes, protected, outcome, NonDes)
mode_var = {'Full': index_full, 'Unaware': index_unaware, 'Fair': index_IFair}
print('index_IFair: ', index_IFair)

# ------Data Splitting------
# One shuffled 5-fold splitter is reused for all four datasets; each split()
# call returns a lazy iterator of (train_positions, test_positions) pairs.
skf = KFold(n_splits=5, shuffle=True, random_state=seed)

split_real0 = skf.split(interventional0_dataset_real)
split_real1 = skf.split(interventional1_dataset_real)
split_gene0 = skf.split(interventional0_dataset)
split_gene1 = skf.split(interventional1_dataset)

# Only this fold (0-based position in the KFold iteration) is run.
SELECTED_FOLD = 3


def _append_result_row(csv_path, columns, row):
    '''Append `row` (a 1 x len(columns) array) to csv_path.

    The header line is written only when the file does not exist yet, so an
    existing results file is never truncated.
    '''
    frame = pd.DataFrame(row, columns=list(columns))
    frame.to_csv(csv_path, mode='a', index=False, header=not os.path.exists(csv_path))


wrapper = zip(split_gene0, split_gene1, split_real0, split_real1)
for fold, ((tr0, tt0), (tr1, tt1), (tr2, tt2), (tr3, tt3)) in enumerate(wrapper):
    if fold != SELECTED_FOLD:
        continue

    # The held-out part of each real dataset is cut in two halves; one half is
    # used for testing and the other for validation, then the roles are swapped.
    half2 = len(tt2) // 2
    half3 = len(tt3) // 2
    held_out_halves = [
        (tt2[:half2], tt2[half2:], tt3[:half3], tt3[half3:]),
        (tt2[half2:], tt2[:half2], tt3[half3:], tt3[:half3]),
    ]
    for t2_test, t2_val, t3_test, t3_val in held_out_halves:

        test_index = [t2_test, t3_test, t2_test, t3_test]
        val_index = [tt0, tt1, t2_val, t3_val]
        train_index = [tr0, tr1, tr2, tr3]

        process2index = {"Train": train_index, "Validation": val_index, "Test": test_index}

        #%%
        ####################################
        #### Train and predict Baselines
        ####################################
        lambda_ = 0.0
        RMSE = np.zeros([1, len(mode_var)])
        Unfairness = np.zeros([1, len(mode_var)])
        for mode_pos, mode in enumerate(mode_var.keys()):
            var_ind = mode_var[mode]
            if len(var_ind) == 0:
                # No input features for this mode: record NaN and skip training
                # (without the `continue` a 0-input model would be trained and
                # the NaN overwritten).
                RMSE[0, mode_pos] = np.nan
                Unfairness[0, mode_pos] = np.nan
                continue
            # Only one test split is scored per mode, so these hold a single value.
            RMSE_cv = []
            Unfairness_cv = []
            # Train
            model = Model(dim=len(var_ind))
            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0)
            train(model, optimizer, process2index, num_iters=BASE_ITER, eval_every=1, lambda_=lambda_, mode=mode)
            # Predict with the checkpoint written during training
            best_model = Model(dim=len(var_ind))
            load_checkpoint(f'{dir}/Model_{mode}.pt', best_model)
            test_loss, test_RMSE, test_unfairness = evaluation(best_model, process2index, mode=mode, lambda_=lambda_)

            print("test_loss: {}, test_prediction_loss: {}, test_unfairness: {}".format(test_loss, test_RMSE, test_unfairness))

            RMSE_cv.append(test_RMSE.item())
            Unfairness_cv.append(test_unfairness.item())
            print("Finished Training {} with lambda={}!".format(mode, lambda_))
            RMSE[0, mode_pos] = round(sum(RMSE_cv) / len(RMSE_cv), 3)
            Unfairness[0, mode_pos] = round(sum(Unfairness_cv) / len(Unfairness_cv), 3)

        # save data
        _append_result_row("{}/RMSE.csv".format(dir), mode_var.keys(), RMSE)
        _append_result_row("{}/Unfairness.csv".format(dir), mode_var.keys(), Unfairness)


        #%%
        #################################
        #### Train the model e-IFair (with generated interventionals)
        #################################
        # lambdas = [0.0, 0.5, 1, 2, 3, 5, 8, 10, 20, 30, 40]
        lambdas = [0.0, 1, 3, 8, 10, 20, 40]
        mode = 'e-IFair'
        RMSEIF = np.zeros([1, len(lambdas)])
        UnfairnessIF = np.zeros([1, len(lambdas)])
        for lambda_pos, lambda_ in enumerate(lambdas):
            RMSEIF_cv = []
            UnfairnessIF_cv = []
            model = Model(dim=len(index_full))
            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0)
            train_loss_full, val_loss_full = train(model, optimizer, process2index, num_iters=GENE_ITER, eval_every=1,
                                                   lambda_=lambda_, mode=mode, ground_truth=False)
            best_model = Model(dim=len(index_full))
            load_checkpoint(f'{dir}/Model_{mode}.pt', best_model)
            test_loss, test_RMSE, test_unfairness = evaluation(best_model, process2index, mode=mode, lambda_=lambda_,
                                                               ground_truth=False)
            print("test_loss: {}, test_prediction_loss: {}, test_unfairness: {}".format(test_loss, test_RMSE, test_unfairness))

            RMSEIF_cv.append(test_RMSE.item())
            UnfairnessIF_cv.append(test_unfairness.item())
            print("Finished Training {} with lambda={}!".format(mode, lambda_))
            RMSEIF[0, lambda_pos] = round(sum(RMSEIF_cv) / len(RMSEIF_cv), 3)
            UnfairnessIF[0, lambda_pos] = round(sum(UnfairnessIF_cv) / len(UnfairnessIF_cv), 3)

        # save data
        lambda_columns = [str(value) + "IF" for value in lambdas]
        _append_result_row("{}/RMSEIF.csv".format(dir), lambda_columns, RMSEIF)
        _append_result_row("{}/UnfairnessIF.csv".format(dir), lambda_columns, UnfairnessIF)

