import json
from absl import app
from absl import flags
from absl import logging
from datetime import datetime
import torch
import os
import pandas as pd
import numpy as np
import time
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import sys
sys.path.insert(1, './Code')
import Code.Adult.functions as functions
import Code.Adult.architecture as architecture
import Code.HSCIC as hs


# ------------------------------- DUMMY PARAMS --------------------------------
# Hyper-parameters of the experiment. All of them are absl flags, so the
# defaults below can be overridden on the command line.

# Upper bound of the epoch loops that train the three VAEs (trainM/trainL/trainR).
flags.DEFINE_integer("epochs_U", 100, "epochs vae")
# Weight that multiplies the HSCIC term added to the classifier loss in `main`.
# NOTE(review): declared as an integer flag, so fractional weights are rejected.
flags.DEFINE_integer("beta", 0, "beta")
# Number of epochs used to train the classifier (cnet).
flags.DEFINE_integer("epochs", 200, "epochs")
# Batch size of both the train and the test DataLoader.
flags.DEFINE_integer("batch_size", 128, "batch size")
# Adam learning rate for the three VAEs.
flags.DEFINE_float("lr_U", 0.01, "learning rate vae")
# Adam learning rate for the classifier.
flags.DEFINE_float("lr", 0.001, "learning rate")
# Used twice in `main`: as the number of test points that are drawn and as the
# number of counterfactual samples generated for each of those points.
flags.DEFINE_integer("count_samples", 100, "number counterfactual samples")
# NOTE(review): this flag is not read anywhere in the code visible in this file.
flags.DEFINE_integer("data_points_count", 1000, "number data points used for counterfactual samples")

# ---------------------------- INPUT/OUTPUT -----------------------------------
# Despite its name, `data_dir` is passed directly to `pd.read_csv`, i.e. it is
# the path of the input file.
flags.DEFINE_string("data_dir", "./pre_processed_adult_data.txt",
                    "Directory of the input data.")
# Parent directory; each run writes into a sub-folder of it (see `output_name`).
flags.DEFINE_string("output_dir", "./results/",
                    "Path to the output directory (for results).")
flags.DEFINE_string("output_name", "",
                    "Name for result folder. Use timestamp if empty.")

# ------------------------------ MISC -----------------------------------------
flags.DEFINE_integer("seed", 0, "The random seed.")
# Global handle through which the parsed flag values are read.
FLAGS = flags.FLAGS


# local functions
# training functions of variational autoencoders for unobserved variables M, L, R
def trainM(epoch, modelM, trainloader, optimizerM):
    """Run one training epoch of the VAE for the unobserved variable M.

    Args:
        epoch: Epoch number; only used in the progress message.
        modelM: The VAE to train. Calling it on a batch must return a 6-tuple
            whose first four entries are (reconstruction, target, mu, logvar).
        trainloader: Iterable of 2-D batches with 10 columns and a `dataset`
            attribute supporting `len()`. The first 9 columns are fed to the
            model; the last column is ignored here.
        optimizerM: Optimizer that updates the parameters of `modelM`.

    Prints the per-batch average of the likelihood and KL terms and the total
    loss divided by the number of samples in the dataset.
    """
    modelM.train()
    # Accumulate plain Python floats. Summing the loss tensors themselves would
    # keep their autograd history alive for the whole epoch.
    train_loss = 0.0
    train_ngll = 0.0
    train_kl = 0.0
    num_batches = 0

    for data in trainloader:
        optimizerM.zero_grad()

        inputs, _ = torch.split(data, (9, 1), 1)
        inputs = inputs.float()

        recon_batch, target, mu, logvar, _, _ = modelM(inputs)

        l_ngll, kl = functions.loss_functionM(recon_batch, target, mu, logvar)
        loss = l_ngll + kl

        loss.backward()
        optimizerM.step()

        train_loss += loss.item()
        train_ngll += float(l_ngll)
        train_kl += float(kl)
        num_batches += 1

    # Same denominator as before (last batch index + 1); 1 for an empty loader.
    denom = max(num_batches, 1)
    print('Likelihood {:.6f}, KL {:.6f}'.format(
                train_ngll / denom, train_kl / denom))
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(trainloader.dataset)))


def trainL(epoch, modelL, trainloader, optimizerL):
    """Run one training epoch of the VAE for the unobserved variable L.

    Args:
        epoch: Epoch number; only used in the progress message.
        modelL: The VAE to train. Calling it on a batch must return a 6-tuple
            whose first four entries are (L_mu, target, mu, logvar).
        trainloader: Iterable of 2-D batches with 10 columns and a `dataset`
            attribute supporting `len()`. The first 9 columns are fed to the
            model; the last column is ignored here.
        optimizerL: Optimizer that updates the parameters of `modelL`.

    Prints the per-batch average of the likelihood and KL terms and the total
    loss divided by the number of samples in the dataset.
    """
    modelL.train()
    # Accumulate plain Python floats. Summing the loss tensors themselves would
    # keep their autograd history alive for the whole epoch.
    train_loss = 0.0
    train_ngll = 0.0
    train_kl = 0.0
    num_batches = 0

    for data in trainloader:
        optimizerL.zero_grad()

        inputs, _ = torch.split(data, (9, 1), 1)
        inputs = inputs.float()

        L_mu, target, mu, logvar, _, _ = modelL(inputs)
        l_ngll, kl = functions.loss_functionL(L_mu.squeeze(), target, mu, logvar)
        loss = l_ngll + kl

        loss.backward()
        optimizerL.step()

        train_loss += loss.item()
        train_ngll += float(l_ngll)
        train_kl += float(kl)
        num_batches += 1

    # Same denominator as before (last batch index + 1); 1 for an empty loader.
    denom = max(num_batches, 1)
    print('Likelihood ({:.6f}), KL {:.6f}'.format(
          train_ngll / denom, train_kl / denom))
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(trainloader.dataset)))


def trainR(epoch, modelR, trainloader, optimizerR):
    """Run one training epoch of the VAE for the unobserved variable R.

    Args:
        epoch: Epoch number; only used in the progress message.
        modelR: The VAE to train. Calling it on a batch must return a 6-tuple
            whose first four entries are (R_mu, target, mu, logvar).
        trainloader: Iterable of 2-D batches with 10 columns and a `dataset`
            attribute supporting `len()`. The first 9 columns are fed to the
            model; the last column is ignored here.
        optimizerR: Optimizer that updates the parameters of `modelR`.

    Prints the per-batch average of the likelihood and KL terms and the total
    loss divided by the number of samples in the dataset.
    """
    modelR.train()
    # Accumulate plain Python floats. Summing the loss tensors themselves would
    # keep their autograd history alive for the whole epoch.
    train_loss = 0.0
    train_ngll = 0.0
    train_kl = 0.0
    num_batches = 0

    for data in trainloader:
        optimizerR.zero_grad()

        inputs, _ = torch.split(data, (9, 1), 1)
        inputs = inputs.float()

        R_mu, target, mu, logvar, _, _ = modelR(inputs)
        l_ngll, kl = functions.loss_functionR(R_mu.squeeze(), target, mu, logvar)
        loss = l_ngll + kl

        loss.backward()
        optimizerR.step()

        train_loss += loss.item()
        train_ngll += float(l_ngll)
        train_kl += float(kl)
        num_batches += 1

    # Same denominator as before (last batch index + 1); 1 for an empty loader.
    denom = max(num_batches, 1)
    print('Likelihood ({:.6f}), KL {:.6f}'.format(
          train_ngll / denom, train_kl / denom))
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(trainloader.dataset)))


# =============================================================================
# MAIN
# =============================================================================

# Number of Monte-Carlo draws used to estimate one counterfactual prediction.
_MC_SAMPLES = 500


def _load_standardized_data(path):
    """Read the pre-processed Adult file and min-max scale every column.

    The returned array has the column order
    [age, male, native_country, married, higher_edu, managerial_occ,
     high_hours, gov_jobs, race, high_income].
    """
    data = pd.read_csv(path, sep=",",
                       names=['high_income', 'male', 'married', 'higher_edu', 'managerial_occ', 'high_hours',
                              'gov_jobs', 'race', 'age', 'native_country'], skipinitialspace=True)
    # The first row is dropped (presumably the header line, since explicit
    # column names are supplied above).
    data = data.drop(0, axis=0)
    data = data.reset_index()
    data = data[['age', 'male', 'native_country', 'married', 'higher_edu', 'managerial_occ',
                 'high_hours', 'gov_jobs', 'race', 'high_income']]
    # Strip the trailing character of the last column of the file.
    data['native_country'] = data['native_country'].str[:-1]

    data = data.apply(pd.to_numeric, errors='ignore')
    data_stand = (data - data.min()) / (data.max() - data.min())
    data_stand = data_stand.apply(pd.to_numeric, errors='ignore')
    return data_stand.to_numpy()


def _fit_vae(model, train_epoch, trainloader):
    """Train `model` for FLAGS.epochs_U epochs with Adam and return it."""
    optimizer = optim.Adam(model.parameters(), lr=FLAGS.lr_U)
    # Epochs are numbered 1..epochs_U inclusive so that exactly epochs_U
    # epochs are run.
    for epoch in range(1, FLAGS.epochs_U + 1):
        train_epoch(epoch, model, trainloader, optimizer)
    return model


def _hscic_args(inputs):
    """Return the (age, male) and (native_country, race) column pairs passed to HSCIC."""
    return inputs[:, [0, 1]], inputs[:, [2, 8]]


def _sample_posterior(mu, logvar, n_samples):
    """Draw `n_samples` rows from N(mu, exp(logvar)); result is (n_samples, latent_dim).

    Assumes `logvar` is the log-variance, so the standard deviation is
    exp(0.5 * logvar).
    """
    mean = mu.reshape(1, -1)
    std = torch.exp(0.5 * logvar).reshape(1, -1)
    noise = torch.randn(n_samples, mean.shape[1], dtype=mean.dtype, device=mean.device)
    return mean + std * noise


def _counterfactual_prediction(inputs_cf, posteriors, vaes, cnet):
    """Monte-Carlo estimate of cnet's class-1 probability under an intervention.

    Args:
        inputs_cf: (1, 9) tensor holding the observation with the intervened
            values already written into columns 0 and 1.
        posteriors: ((mu_M, logvar_M), (mu_L, logvar_L), (mu_R, logvar_R))
            obtained from the encoders.
        vaes: (modelM, modelL, modelR); only their `decode` methods are used.
        cnet: The trained classifier.

    Returns:
        0-d tensor: mean class-1 probability over `_MC_SAMPLES` draws.
    """
    (mu_M, logvar_M), (mu_L, logvar_L), (mu_R, logvar_R) = posteriors
    modelM, modelL, modelR = vaes

    inputs_expand = inputs_cf.repeat(_MC_SAMPLES, 1)
    # Columns 0..2 (age, male, native_country) are the parents of M.
    parents_M = inputs_expand[:, :3].clone()

    # decode sample + posterior samples & take the most likely class for M
    post_M = _sample_posterior(mu_M, logvar_M, _MC_SAMPLES)
    M_CF = torch.max(modelM.decode(torch.cat((parents_M, post_M), dim=1)), 1)[1]

    post_L = _sample_posterior(mu_L, logvar_L, _MC_SAMPLES)
    # M_CF holds integer class indices; cast so it can be concatenated.
    parents_L = torch.cat((parents_M, M_CF.view(-1, 1).to(parents_M.dtype)), 1)
    L_CF = torch.squeeze(modelL.decode(torch.cat((parents_L, post_L), dim=1)))

    post_R = _sample_posterior(mu_R, logvar_R, _MC_SAMPLES)
    parents_R = torch.cat((parents_L, L_CF.view(-1, 1)), 1)
    R_CF = torch.squeeze(modelR.decode(torch.cat((parents_R, post_R), dim=1)))

    # merge all the CF data together: column 3 (married), column 4
    # (higher_edu) and columns 5..7 (managerial_occ, high_hours, gov_jobs)
    inputs_expand[:, 3] = M_CF
    inputs_expand[:, 4] = L_CF
    inputs_expand[:, 5:8] = R_CF

    return F.softmax(cnet(inputs_expand), dim=1)[:, 1].mean()


def main(_):
    """Train the VAEs and the classifier, evaluate them and store the results.

    Steps:
      1. Create the output directory, redirect the absl log file and dump the
         flag values to `flags.json`.
      2. Seed NumPy and torch, load and standardise the data, split 80/20.
      3. Train the three VAEs (M, L, R), then the classifier with the loss
         `cross_entropy + beta * HSCIC`.
      4. Compute test accuracy and the mean test HSCIC.
      5. For randomly drawn test points, generate counterfactual predictions
         and compute the per-point standard deviation of those predictions.
      6. Save `[accuracy, mean HSCIC, mean per-point std]` under the key
         `str(FLAGS.beta)` in `results.npz`.
    """
    # ---------------------------------------------------------------------------
    # Directory setup, save flags, set random seed
    # ---------------------------------------------------------------------------
    if FLAGS.output_name == "":
        dir_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    else:
        dir_name = FLAGS.output_name
    out_dir = os.path.join(os.path.abspath(FLAGS.output_dir), dir_name)
    logging.info(f"Save all output to {out_dir}...")
    os.makedirs(out_dir, exist_ok=True)

    FLAGS.log_dir = out_dir
    logging.get_absl_handler().use_absl_log_file(program_name="run")

    logging.info("Save FLAGS (arguments)...")
    with open(os.path.join(out_dir, 'flags.json'), 'w') as fp:
        json.dump(FLAGS.flag_values_dict(), fp, sort_keys=True, indent=2)

    logging.info(f"Set random seed {FLAGS.seed}...")
    # Actually apply the seed that is announced in the log message.
    np.random.seed(FLAGS.seed)
    torch.manual_seed(FLAGS.seed)

    # ---------------------------------------------------------------------------
    # Load data
    # ---------------------------------------------------------------------------
    data_stand = _load_standardized_data(FLAGS.data_dir)
    train, test = train_test_split(data_stand, test_size=0.2, random_state=FLAGS.seed)

    trainloader = DataLoader(train, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0)
    testloader = DataLoader(test, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0)

    # ---------------------------------------------------------------------------
    # Train the VAEs for the unobserved variables M, L, R
    # ---------------------------------------------------------------------------
    modelM = _fit_vae(architecture.VAE_M(), trainM, trainloader)
    modelL = _fit_vae(architecture.VAE_L(), trainL, trainloader)
    modelR = _fit_vae(architecture.VAE_R(), trainR, trainloader)

    cnet = architecture.ClassNet()
    optimizer = optim.Adam(cnet.parameters(), lr=FLAGS.lr)
    criterion = nn.CrossEntropyLoss()
    hscic = hs.HSCIC()

    # ---------------------------------------------------------------------------
    # Train cnet
    # ---------------------------------------------------------------------------
    start_time = time.time()
    cnet.train()
    for _ in range(FLAGS.epochs):
        for data in trainloader:
            # first 9 columns are the inputs, the last column is the label
            inputs, labels = torch.split(data, (9, 1), 1)
            inputs = inputs.float()
            # squeeze(1) keeps the batch dimension even for a batch of size 1
            labz = labels.long().squeeze(1)

            optimizer.zero_grad()

            outputs = cnet(inputs)
            loss = criterion(outputs, labz)
            hscic_value = hscic(outputs.float(), *_hscic_args(inputs))

            loss_model = loss + FLAGS.beta * hscic_value
            loss_model.backward()
            optimizer.step()

    print("--- %s seconds ---" % (time.time() - start_time))
    print('Finished Training')

    # ---------------------------------------------------------------------------
    # Test cnet
    # ---------------------------------------------------------------------------
    # Switch every model to inference mode for evaluation and sampling.
    for model in (cnet, modelM, modelL, modelR):
        model.eval()

    correct = 0
    total = 0
    test_hscic = []
    with torch.no_grad():
        for data in testloader:
            inputs, labels = torch.split(data, (9, 1), 1)
            inputs = inputs.float()
            labz = labels.long().squeeze(1)

            outputs = cnet(inputs)
            hscic_value = hscic(outputs.float(), *_hscic_args(inputs))

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            test_hscic.append(hscic_value.item())
            correct += (predicted == labz).sum().item()

    accuracy = 100 * correct / total
    print(accuracy)
    # Separate name: `hscic` must keep referring to the HSCIC module.
    mean_test_hscic = sum(test_hscic) / len(test_hscic)

    # NOTE(review): the number of test points is taken from `count_samples`;
    # the `data_points_count` flag looks like it was meant for this - confirm.
    index = np.random.randint(low=0, high=len(test), size=FLAGS.count_samples)
    selected_testloader = DataLoader(test[index])

    # ---------------------------------------------------------------------------
    # Find counterfactual outcomes
    # ---------------------------------------------------------------------------
    vaes = (modelM, modelL, modelR)
    cf_predictions = []
    data_point = []
    with torch.no_grad():
        for k, data in enumerate(selected_testloader):
            inputs, _ = torch.split(data, (9, 1), 1)
            inputs = inputs.float()

            # Posteriors of the unobserved variables given the observed point.
            # They do not depend on the intervention, so compute them once.
            posteriors = (modelM.encode(inputs), modelL.encode(inputs), modelR.encode(inputs))

            for _ in range(FLAGS.count_samples):
                # Work on a copy: writing into `inputs` itself would replace
                # the observed age/gender that the encoders are conditioned on.
                inputsCF = inputs.clone()
                inputsCF[0, 0] = float(np.random.choice(train[:, 0]))  # sample age
                inputsCF[0, 1] = float(np.random.choice(train[:, 1]))  # sample gender

                pred = _counterfactual_prediction(inputsCF, posteriors, vaes, cnet)
                cf_predictions.append(pred.item())
                data_point.append(k)

        print("--- %s seconds ---" % (time.time() - start_time))

    results = pd.DataFrame({'CF_prediction': cf_predictions,
                            'data_point_number': data_point})

    # Spread of the counterfactual predictions of each data point, then the
    # average of that spread over all data points.
    grouped_results = results.groupby(['data_point_number'], as_index=False).std()
    grouped_results_vcf = grouped_results.mean()

    results_dict = {
        str(FLAGS.beta): [accuracy, mean_test_hscic, grouped_results_vcf['CF_prediction']],
    }

    logging.info("Store results...")
    result_path = os.path.join(out_dir, "results.npz")

    np.savez(result_path, **results_dict)

    logging.info("DONE")


# Script entry point: absl parses the command-line flags and then calls `main`.
if __name__ == "__main__":
    app.run(main)
