#%%%
import pandas as pd

def get_dummy(df, features):
    """Return a copy of ``df`` with the ``features`` columns one-hot encoded.

    The first level of every encoded column is dropped (``drop_first=True``),
    so a column with k distinct values yields k - 1 indicator columns.
    """
    return pd.get_dummies(data=df, columns=features, drop_first=True)

# ------path setting------
# Presumably a placeholder for the project root; set it before running.
path = "PATH_TO_DATA"
# NOTE: `dir` shadows the builtin, but later cells read it by this name.
dir = path + "/data"
data_path = "{}/credit_preprocessed.csv".format(dir)

# ------file read in and setting up------
# Comma-separated file whose first row holds the column names.
data = pd.read_csv(
    data_path,
    sep=',',
    header=0,
)

#%%
# ------fit values with parents------
import pickle
from scipy.stats import chi2_contingency


relation_path = '{}/bp_save.txt'.format(dir)
last_inputs = []

pairs,protected,inputs, roots, last_input = [],'',[],[], None
with open(relation_path, 'r') as f:
    # Header line: "<relation count> <protected column>", split on one space.
    n, protected = f.readline().strip().split(' ')

    for i in range(int(n)):
        # One relation per line: two tab-separated halves of space-separated
        # tokens. The first token of each half is discarded; the rest are
        # used as column names of `data` (b = modelled columns, p = parents).
        b, p = f.readline().split('\t')
        b = b.strip().split(' ')[1:]
        p = p.strip().split(' ')[1:]

        if i >= 6:
            # Relations from index 6 onwards are only collected here; they
            # are not fitted with a contingency table in this cell.
            last_inputs.append([b, p])
            continue

        pairs.append([b, p])
        wfile = "{}/models/model_{}".format(dir, i)
        print("...........Start model_{}...........".format(i))

        bucket = [data[a] for a in b]
        parent = [data[a] for a in p]

        # Joint counts of b (rows) against p (columns). margins=True appends
        # a totals row and column, which the [:-1] slices below strip again.
        ctgt = pd.crosstab(bucket, parent, margins=True)

        # Sanity check that b and p are dependent. Raised explicitly (same
        # exception type as the former `assert`) so it also runs under -O.
        p_value = chi2_contingency(ctgt.iloc[:-1, :-1])[1]
        if not (p_value < 0.01):
            raise AssertionError(
                "chi-square test for model_{} gave p-value {}, expected < 0.01"
                .format(i, p_value))

        # The row labels are the same for every parent value, so build the
        # list once; it is only read when the model is used for sampling.
        list_value = ctgt.index.to_list()[:-1]

        # cur_model maps each parent value to (child values, P(child | parent)).
        cur_model = dict()
        for attri in ctgt.columns.to_list()[:-1]:
            cur_nums = ctgt[attri].to_numpy()
            # Last entry of the column is its total (the margins row).
            cur_model[attri] = (list_value, cur_nums[:-1] / cur_nums[-1])

        # `with` closes the file even if pickling fails.
        with open(wfile+'.mdl', 'wb') as file:
            pickle.dump(cur_model, file)
        print("...........model_{} done!...........".format(i))


#%%
# generate data
import pickle
import random
import numpy as np

# n_sample = len(data)
# Row counts of the two Age groups in the observed data, and the column count.
nob0 = len(data[data.Age==23])
nob1 = len(data[data.Age==30])
d = len(data.columns)

# Zero-filled frames with the observed data's columns, one per group.
gened_data0 = pd.DataFrame(np.zeros((nob0, d)), columns=data.columns)
gened_data1 = pd.DataFrame(np.zeros((nob1, d)), columns=data.columns)

# generate intervention variables: fix the protected column to each group's value
gened_data0[protected] = 23
gened_data1[protected] = 30

#%%
# generate conditional probability
SEED = 532
random.seed(SEED)


def _draw(model, key):
    """Draw one value from the (values, weights) pair stored under ``key``.

    Returns the single-element list produced by ``random.choices``.
    """
    entry = model[key]
    return random.choices(entry[0], entry[1])


for i, (outputs, inputs) in enumerate(pairs):
    # Parent columns are snapshotted once per model, before any child value
    # of this model is written.
    input0 = gened_data0[inputs]
    input1 = gened_data1[inputs]

    wfile = "{}/models/model_{}".format(dir, i)
    # NOTE: pickle.load can execute arbitrary code; only load model files
    # produced by this script.
    with open(wfile+'.mdl', 'rb') as file:
        model = pickle.load(file)

    # Group 0 is filled completely before group 1, which keeps the order of
    # random draws (and thus the seeded output) unchanged.
    for frame, parents, n_rows in ((gened_data0, input0, nob0),
                                   (gened_data1, input1, nob1)):
        for j in range(n_rows):
            row = parents.iloc[j]
            if i < 6:
                # One draw is made per parent value; the draw belonging to
                # the first parent is kept. `.iloc[0]` is the explicit form
                # of the positional lookup (plain `[0]` on a string-labelled
                # Series is deprecated in recent pandas).
                draws = row.apply(lambda attri: _draw(model, attri))
                frame.loc[j, outputs] = draws.iloc[0][0]
            else:
                # Models keyed by the full tuple of parent values. The weights
                # are the second element of the entry; the previous code
                # passed the values themselves as weights.
                # NOTE(review): `pairs` is only filled for the first six
                # relations, so this branch is not reached at present.
                frame.loc[j, outputs] = _draw(model, tuple(row.values))

    print("...........model_{} used!...........".format(i))


#%%
# fitting penultimate models
# Keep an unmodified copy: `data` is standardized and one-hot encoded below,
# while the raw Loan_Amount values are read again from `data_origin` when the
# generated values are snapped to observed ones.
data_origin = data.copy()

# penultimate: prepare data
import torch
from torch.utils.data import DataLoader
import numpy as np
import ctgnetwork as mdn
from sklearn.preprocessing import StandardScaler
from utils import save_checkpoint, load_checkpoint

num_vars = ['Income', 'Loan_Amount']
cat_vars = ['Loan_Grade']

# Standardize each numeric column in place and keep its fitted scaler, in the
# same order as num_vars, so generated values can be transformed later.
# NOTE: the loop index `i` stays bound at module level and is read by the
# training cell further down.
scalers = []
for i, feature in enumerate(num_vars):
    column = data[feature].values.reshape(-1, 1)
    scaler = StandardScaler()
    data[feature] = scaler.fit(column).transform(column)
    scalers.append(scaler)

# One-hot encode the categorical columns, keeping every level.
data = pd.get_dummies(data, columns=cat_vars, drop_first=False, prefix=cat_vars)

# Replace the parent names of the first remaining relation with every column
# of the encoded frame whose name starts with one of them, so an encoded
# categorical parent expands to its indicator columns.
parent_prefixes = tuple(last_inputs[0][1])
last_inputs[0][1] = [name for name in data.columns if name.startswith(parent_prefixes)]

#%%
# last 2: training
# Hyperparameters for the models trained in the loop below.
# NOTE(review): the trailing "# <number>" comments look like alternative
# values from other runs — confirm before relying on them.
# Seed passed to torch.manual_seed before each model is built.
RANDOM_SEED = 532   # 42
# Number of passes over the training split.
EPOCHS = 20   # 200
# Learning rate of the Adam optimizer.
LR = 1e-2
# Batch size of the training DataLoader.
BATCH_SIZE = 1024
# Third and fourth arguments of mdn.MDN (presumably the number of mixture
# components and the hidden width — confirm against ctgnetwork).
NUM_GAUSSIAN = 4 # 4
NUM_HIDDEN = 8 # 16
# StepLR schedule: multiply the learning rate by GAMMA every STEP_SIZE epochs.
STEP_SIZE = 10
GAMMA = 0.5

labels = []

# Loop-invariant setup: the split sizes depend only on len(data), and the
# device and loss function are the same for every model.
train_size = int(len(data) * 0.8)
test_size = len(data) - train_size
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loss_fn = mdn.mdn_loss
# Validate (and checkpoint on improvement) every VAL_INTERVAL epochs. The
# previous gate tested a stale module-level `i` with `% 1`, which was always
# true; an interval of 1 keeps that behaviour but ties it to the epoch.
VAL_INTERVAL = 1

for b, p in last_inputs:

    # Re-seed per model so the split and the weight initialisation are
    # reproducible regardless of how many models were trained before.
    torch.manual_seed(RANDOM_SEED)
    dataset = mdn.MyDataset(data, b, p)

    # 80% train; the remainder is halved into validation and test.
    trainData, tempData = torch.utils.data.random_split(dataset, [train_size, test_size])
    valData, testData = torch.utils.data.random_split(tempData, [test_size//2, test_size - test_size//2])
    trainSet = DataLoader(dataset=trainData, shuffle=True, drop_last=False, batch_size=BATCH_SIZE)
    valSet = DataLoader(dataset=valData, shuffle=False, drop_last=False, batch_size=len(valData))
    testSet = DataLoader(dataset=testData, shuffle=False, drop_last=False, batch_size=len(testData))

    wfile = "{}/models/model_{}".format(dir, b[0])

    model = mdn.MDN(len(p), len(b), NUM_GAUSSIAN, NUM_HIDDEN)
    model = model.to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

    print("...........Start model_{}...........".format(b[0]))
    best_val_loss = float("Inf")

    for epoch in range(EPOCHS):
        train_loss, train_count = 0.0, 0
        for x, y in trainSet:
            x, y = x.to(device), y.to(device)

            train_count += 1
            model.zero_grad()
            output = model(x)
            loss = loss_fn(output, y)

            # .item() works for CPU and CUDA tensors alike, whereas
            # .numpy() raises for tensors that live on the GPU.
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

        if epoch % VAL_INTERVAL == 0:
            val_loss, val_count = 0.0, 0
            with torch.no_grad():
                for x, y in valSet:
                    x, y = x.to(device), y.to(device)
                    val_count += 1
                    output = model(x)
                    val_loss += loss_fn(output, y).item()

            # Checkpoint whenever the mean validation loss improves.
            cur_val_loss = val_loss / val_count
            if best_val_loss > cur_val_loss:
                best_val_loss = cur_val_loss
                save_checkpoint('{}/Model_{}.pt'.format(dir,b[0]), model, best_val_loss)

            print('Epoch: {}, TrainLoss: {:.3f}, ValLoss: {:.3f}'.format(epoch + 1, train_loss / train_count,
                                                                    cur_val_loss))

        scheduler.step()

    # Reload the best checkpoint into a fresh model and store it next to the
    # table models so the generation cell can load it by name.
    best_model = mdn.MDN(len(p), len(b), NUM_GAUSSIAN, NUM_HIDDEN)

    load_checkpoint('{}/Model_{}.pt'.format(dir,b[0]), best_model)

    # `with` closes the file even if pickling fails.
    with open(wfile+'.mdl', 'wb') as file:
        pickle.dump(best_model, file)
    print("...........model_{} done!...........".format(b[0]))

#%%
# prepare model inputs for generating the remaining columns
# Work on copies so the generated frames themselves keep unscaled values.
copy0, copy1 = gened_data0.copy(), gened_data1.copy()

# normalize Income: only the first numeric variable is scaled here, using the
# scaler that was fitted on the observed data.
feature, scaler = num_vars[0], scalers[0]
for frame in (copy0, copy1):
    frame[feature] = scaler.transform(frame[feature].values.reshape(-1, 1))

#%%
# One-hot encode the categorical columns of both copies with the same options
# that were used for the training data (every level kept).
dummy_options = dict(columns=cat_vars, drop_first=False, prefix=cat_vars)
copy0 = pd.get_dummies(copy0, **dummy_options)
copy1 = pd.get_dummies(copy1, **dummy_options)


#%%
# generate!
# Sampling function applied to the model output; looked up once.
sample = mdn.sample

for b, p in last_inputs:

    wfile = "{}/models/model_{}".format(dir, b[0])
    # NOTE: pickle.load can execute arbitrary code; only load model files
    # produced by this script. `with` closes the file even if loading fails.
    with open(wfile+'.mdl', 'rb') as file:
        model = pickle.load(file)

    # Datasets are built from the scaled / encoded copies using the parent
    # columns p only.
    input0 = mdn.MyDataset(copy0, p)
    input1 = mdn.MyDataset(copy1, p)

    # batch_size equals the dataset length, so each loader yields one batch.
    dataSet0 = DataLoader(dataset=input0, shuffle=False, drop_last=False, batch_size=len(input0))
    dataSet1 = DataLoader(dataset=input1, shuffle=False, drop_last=False, batch_size=len(input1))

    with torch.no_grad():
        for x in dataSet0:
            outputs = model(x)
            gened_data0[b] = sample(outputs).numpy().reshape(-1,1)

        for x in dataSet1:
            outputs = model(x)
            gened_data1[b] = sample(outputs).numpy().reshape(-1,1)

    print("model_{} used!".format(b[0]))


#%%

# Skip Income: only the remaining numeric variables are mapped back to the
# original scale here.
for feature, scaler in zip(num_vars[1:], scalers[1:]):
    # Bounds come from `data`, so after the standardization cell they are
    # expressed in scaled units, like the generated values being clipped.
    lower, upper = data[feature].min(), data[feature].max()
    for frame in (gened_data0, gened_data1):
        clipped = np.clip(frame[feature].values.reshape(-1, 1), lower, upper)
        frame[feature] = scaler.inverse_transform(clipped)


#%%
def _snap_to_observed(generated, observed):
    """Map every generated value to the nearest value occurring in ``observed``.

    Args:
        generated: pandas Series of generated numeric values.
        observed: pandas Series of observed values for the same variable.

    Returns:
        A NumPy array with one entry per generated value, each taken from the
        distinct values of ``observed``. On equal distance the candidate that
        comes first in ``observed.value_counts()`` is chosen.
    """
    # Building the tensor from the index array keeps the labels' own dtype; a
    # list of Python floats would be stored as float32 by torch.tensor.
    candidates = torch.tensor(observed.value_counts().index.to_numpy())
    values = torch.tensor(generated.to_numpy()).unsqueeze(1)
    # Broadcasting gives an (n_generated, n_candidates) distance matrix.
    nearest = torch.argmin(torch.abs(values - candidates.unsqueeze(0)), dim=1)
    return candidates[nearest].numpy()


for feature in num_vars[1:]:
    # Each group is snapped to the values observed for that Age group. The
    # column is taken from `feature`, so every variable in num_vars[1:] is
    # handled by its own observed values.
    gened_data0.loc[:, feature] = _snap_to_observed(
        gened_data0[feature], data_origin.loc[data_origin.Age == 23, feature])
    gened_data1.loc[:, feature] = _snap_to_observed(
        gened_data1[feature], data_origin.loc[data_origin.Age == 30, feature])


#%%
# generate the last column
# Loan_Percent_Income is derived from the generated Loan_Amount and Income
# columns and rounded to two decimals.
for frame in (gened_data0, gened_data1):
    ratio = frame['Loan_Amount'] / frame['Income']
    frame['Loan_Percent_Income'] = ratio.round(2)


#%%
# save data
# Both frames are written as comma-separated files with a header row and
# without the index column.
csv_options = dict(sep=',', index=False, header=True)
gened_data0.to_csv("{}/interventional0_data_gened.csv".format(dir), **csv_options)
gened_data1.to_csv("{}/interventional1_data_gened.csv".format(dir), **csv_options)

