import json
import os

from absl import app
from absl import flags
from absl import logging

from datetime import datetime
from torch import nn, optim
import numpy as np
import sys
sys.path.insert(1, './Code')
import Code.Tabular.functions as functions
import Code.Tabular.architecture as architecture

# ------------------------------- PARAMS --------------------------------
# Training, data-simulation and model-size hyper-parameters.
# NOTE(review): `beta` is declared as an integer flag, so fractional values
# such as --beta=0.5 are rejected by absl's integer parser. Switching it to
# DEFINE_float would also change the str(FLAGS.beta) key written into
# results.npz by main(); confirm before changing.
flags.DEFINE_integer(name="epochs", default=100, help="epochs")
flags.DEFINE_integer(name="batch_size", default=256, help="batch size")
flags.DEFINE_integer(name="number_samples", default=10000, help="number samples")
flags.DEFINE_integer(name="beta", default=0, help="regularization parameter")
flags.DEFINE_integer(name="sim", default=1, help="scenario simulation")
flags.DEFINE_float(name="lr", default=0.001, help="learning rate")
flags.DEFINE_integer(name="dim_h", default=20, help="hidden dimension")
flags.DEFINE_integer(name="dim_input", default=3, help="input dimension")
flags.DEFINE_integer(name="nh", default=8, help="number hidden layers")
flags.DEFINE_integer(name="count_samples", default=500,
                     help="number counterfactual samples")
flags.DEFINE_integer(name="data_points_count", default=1000,
                     help="number data points used for counterfactual samples")

# ---------------------------- INPUT/OUTPUT -----------------------------------
# Results go to <output_dir>/<output_name>; an empty name means a timestamp.
flags.DEFINE_string(name="output_dir", default="./results/",
                    help="Path to the output directory (for results).")
flags.DEFINE_string(name="output_name", default="",
                    help="Name for result folder. Use timestamp if empty.")

# ------------------------------ MISC -----------------------------------------
flags.DEFINE_integer(name="seed", default=0, help="The random seed.")

# Module-level handle to the parsed flag values (populated by app.run).
FLAGS = flags.FLAGS

# =============================================================================
# MAIN
# =============================================================================


def main(_):
    """Run one training experiment and save its test metrics.

    Steps:
      1. Create the run directory ``<FLAGS.output_dir>/<FLAGS.output_name>``.
         A timestamp is used when ``output_name`` is empty. absl's log file is
         redirected there, and all flag values are dumped to ``flags.json``.
      2. Seed NumPy and PyTorch with ``FLAGS.seed``.
      3. Simulate ``FLAGS.number_samples`` points for scenario ``FLAGS.sim``
         and build train/test loaders via ``functions.data_processing``. Then
         build the counterfactual loaders via
         ``functions.counterfactual_simulations``.
      4. Train ``architecture.Model`` with Adam and an MSE loss. The
         regularization weight ``FLAGS.beta`` is passed to
         ``functions.train_model``.
      5. Save ``[final test accuracy loss, final test HSCIC loss, VCF]`` to
         ``results.npz``, keyed by ``str(FLAGS.beta)``.

    Args:
      _: Remaining command-line arguments forwarded by ``app.run``; unused.
    """
    # Local import: the module only imports `nn`/`optim` from torch, but the
    # global RNG seeding below needs the top-level package.
    import torch

    # ---------------------------------------------------------------------------
    # Directory setup, save flags, set random seed
    # ---------------------------------------------------------------------------
    dir_name = FLAGS.output_name or datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    out_dir = os.path.join(os.path.abspath(FLAGS.output_dir), dir_name)
    logging.info(f"Save all output to {out_dir}...")
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(out_dir, exist_ok=True)

    # Redirect absl's log file into the run directory.
    FLAGS.log_dir = out_dir
    logging.get_absl_handler().use_absl_log_file(program_name="run")

    logging.info("Save FLAGS (arguments)...")
    with open(os.path.join(out_dir, 'flags.json'), 'w') as fp:
        json.dump(FLAGS.flag_values_dict(), fp, sort_keys=True, indent=2)

    logging.info(f"Set random seed {FLAGS.seed}...")
    np.random.seed(FLAGS.seed)
    # The network and optimizer are PyTorch objects, and the loaders are
    # presumably torch DataLoaders. Without seeding torch, weight
    # initialization (and any loader shuffling) would differ between runs
    # that share the same --seed.
    torch.manual_seed(FLAGS.seed)

    # ---------------------------------------------------------------------------
    # Load data
    # ---------------------------------------------------------------------------
    data = functions.simulation_data(FLAGS.number_samples, FLAGS.sim)  # [X, Z, A, Y]
    train_dataloader, test_dataloader, _, _, _ = functions.data_processing(data, FLAGS.batch_size)

    # Generate counterfactual samples: each element of count_loader_tot holds
    # FLAGS.count_samples counterfactual data points (presumably for one of the
    # FLAGS.data_points_count original points; verify in functions).
    count_loader_tot = functions.counterfactual_simulations(
        FLAGS.data_points_count, FLAGS.count_samples, data, FLAGS.sim)

    # ---------------------------------------------------------------------------
    # Train and save test results
    # ---------------------------------------------------------------------------
    loss_function = nn.MSELoss()
    cnet = architecture.Model(dim_in=FLAGS.dim_input, nh=FLAGS.nh, dim_h=FLAGS.dim_h)
    optimizer = optim.Adam(cnet.parameters(), lr=FLAGS.lr)
    # Only the trained model and the final test losses are used below; the
    # remaining loss histories are kept as `_`-prefixed names for readability.
    (cnet,
     _loss_tot_train, _loss_acc_train, _loss_hscic_train,
     _loss_tot_test, loss_acc_test, loss_hscic_test) = functions.train_model(
        train_dataloader, optimizer, cnet, loss_function, FLAGS.beta, FLAGS.epochs,
        test_dataloader)

    # Counterfactual variance: for each counterfactual loader, take the sample
    # variance (ddof=1) of the trained model's outputs on it.
    var_res = [np.var(functions.find_output(cnet, count_loader), ddof=1)
               for count_loader in count_loader_tot]

    # Store accuracy, HSCIC and VCF (mean of var_res), keyed by the beta used.
    results_dict = {
        str(FLAGS.beta): [loss_acc_test[-1], loss_hscic_test[-1], np.mean(var_res)],
    }

    logging.info("Store results...")
    result_path = os.path.join(out_dir, "results.npz")
    np.savez(result_path, **results_dict)

    logging.info("DONE")


if __name__ == "__main__":
    # absl parses sys.argv into FLAGS, then calls main() with the leftover args.
    app.run(main=main)
