import json
import os
from absl import app
from absl import flags
from absl import logging
from datetime import datetime
import numpy as np
from torch import nn, optim
import sys
sys.path.insert(1, './Code')
import Code.Image.network as network
import Code.Image.functions as functions

# ------------------------------- PARAMS --------------------------------
# Flags are parsed from the command line by absl's app.run() and read through
# FLAGS inside main().
flags.DEFINE_integer("epochs", 1000, "epochs")
flags.DEFINE_integer("batch_size", 32, "batch size")
# beta is forwarded to functions.train_model as its regularization weight.
# NOTE(review): declared as an integer, so fractional values (e.g. --beta=0.5)
# are rejected by the parser. Switching to DEFINE_float would also change the
# str(FLAGS.beta) key written to results.npz ("0" -> "0.0"), so it is kept.
flags.DEFINE_integer("beta", 0, "regularization parameter")
flags.DEFINE_float("lr", 0.001, "learning rate")
flags.DEFINE_integer("count_samples", 500, "number counterfactual samples")
flags.DEFINE_integer("data_points_count", 1000, "number data points used for counterfactual samples")
flags.DEFINE_integer("init_n", 3000, "initial number data points used for generating dataset")
# NOTE(review): main() passes len(imgs) to functions.find_causal_dataset rather
# than FLAGS.n_large; confirm whether this flag is read anywhere else.
flags.DEFINE_integer("n_large", 60000, "n datapoints of dsprites used for matching")

# ---------------------------- INPUT/OUTPUT -----------------------------------
# Despite its name, data_dir is the path of the dSprites .npz archive itself
# (it is handed directly to np.load), not a directory.
flags.DEFINE_string("data_dir", "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
                    "Path to the dSprites .npz file with the input data.")
flags.DEFINE_string("output_dir", "./results/",
                    "Path to the output directory (for results).")
flags.DEFINE_string("output_name", "",
                    "Name for result folder. Use timestamp if empty.")

# ------------------------------ MISC -----------------------------------------
flags.DEFINE_integer("seed", 0, "The random seed.")
FLAGS = flags.FLAGS


# =============================================================================
# MAIN
# =============================================================================

def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch global random number generators.

    Args:
      seed: Integer seed applied to all three generators.
    """
    # Local imports: the module header only pulls ``nn``/``optim`` from torch.
    import random

    import torch

    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the CPU and all CUDA generators. Without it the
    # network's weight initialisation (and any torch-side shuffling) would
    # presumably differ between runs even with the same --seed.
    torch.manual_seed(seed)


def _counterfactual_vcf(cnet, count_images, count_tabular):
    """Return the VCF metric for ``cnet``.

    For each data point, ``functions.find_output`` predicts the outcome on all
    of that point's counterfactual samples. The unbiased (``ddof=1``) variance
    of those predictions is computed, and the variances are averaged.

    Args:
      cnet: Trained network passed to ``functions.find_output``.
      count_images: Per-data-point counterfactual image inputs.
      count_tabular: Per-data-point counterfactual tabular inputs, aligned
        one-to-one with ``count_images``.

    Returns:
      Mean of the per-data-point variances of the counterfactual outcomes.

    Raises:
      ValueError: If ``count_images`` and ``count_tabular`` differ in length.
    """
    if len(count_images) != len(count_tabular):
        raise ValueError(
            f"Mismatched counterfactual inputs: {len(count_images)} image sets "
            f"vs {len(count_tabular)} tabular sets")
    variances = [
        np.var(functions.find_output(cnet, images, tabular), ddof=1)
        for images, tabular in zip(count_images, count_tabular)
    ]
    return np.mean(variances)


def main(_):
    """Build the causal dSprites dataset, train the model and store its metrics.

    Steps:
      * create the output directory and send absl logs there;
      * dump the flags to ``flags.json``;
      * seed the RNGs;
      * generate the dataset with ``functions.find_causal_dataset``;
      * train ``network.NeuralNetworkImage`` (Adam, MSE loss) with
        ``functions.train_model``;
      * evaluate VCF on counterfactual samples;
      * save ``[last test accuracy loss, last test HSCIC, VCF]`` under the key
        ``str(FLAGS.beta)`` in ``results.npz``.

    Args:
      _: Positional argv forwarded by ``absl.app.run``; unused.
    """
    # ---------------------------------------------------------------------------
    # Directory setup, save flags, set random seed
    # ---------------------------------------------------------------------------
    dir_name = FLAGS.output_name or datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    out_dir = os.path.join(os.path.abspath(FLAGS.output_dir), dir_name)
    logging.info(f"Save all output to {out_dir}...")
    os.makedirs(out_dir, exist_ok=True)

    # use_absl_log_file() reads FLAGS.log_dir when no log_dir is passed, so the
    # log file ends up next to the results.
    FLAGS.log_dir = out_dir
    logging.get_absl_handler().use_absl_log_file(program_name="run")

    logging.info("Save FLAGS (arguments)...")
    with open(os.path.join(out_dir, 'flags.json'), 'w', encoding='utf-8') as fp:
        json.dump(FLAGS.flag_values_dict(), fp, sort_keys=True, indent=2)

    logging.info(f"Set random seed {FLAGS.seed}...")
    _seed_everything(FLAGS.seed)

    # ---------------------------------------------------------------------------
    # Load, generate and process data
    # ---------------------------------------------------------------------------

    # allow_pickle=True is required to read the object-typed 'metadata' entry.
    # Pickle can execute code, so only point --data_dir at trusted files.
    with np.load(FLAGS.data_dir, allow_pickle=True, encoding="latin1") as dataset_zip:
        # Indexing an NpzFile reads the member fully into memory, so both
        # arrays stay valid after the archive is closed.
        imgs = dataset_zip['imgs']
        metadata = dataset_zip['metadata'][()]  # unwrap the 0-d array; used as a mapping below

    # latents_bases[i] == prod(latents_sizes[i + 1:]) and the last entry is 1.
    # These are mixed-radix place values, presumably used by
    # functions.latent_to_index to turn latent vectors into image indices.
    latents_sizes = metadata['latents_sizes']
    latents_bases = np.concatenate((latents_sizes[::-1].cumprod()[::-1][1:], np.array([1, ])))

    # Generate new dataset of latent variables and outcome via matching procedure.
    # NOTE(review): FLAGS.n_large is not used here; len(imgs) is passed instead.
    # Confirm which value find_causal_dataset is meant to receive.
    df_total = functions.find_causal_dataset(FLAGS.init_n, len(imgs), latents_sizes, latents_bases, imgs)
    # NOTE(review): duplicate rows are kept (a drop_duplicates() call was
    # previously left commented out here).

    # Pick the image for each sampled latent configuration.
    indices_sampled = functions.latent_to_index(df_total.drop(['noise_scale', 'output'], axis=1), latents_bases)
    imgs_df = imgs[indices_sampled]

    # Drop the unobserved variables ('scale', 'noise_scale') and the 'index'
    # column that reset_index() adds, then normalise column-wise.
    df_total_final = df_total.reset_index()
    df_total_final_no_unobs = df_total_final.drop(['scale', 'noise_scale', 'index'], axis=1)
    mean, std = df_total_final_no_unobs.mean(), df_total_final_no_unobs.std()  # pandas .std() uses ddof=1
    df_total_final_no_unobs = functions.normalize_data(df_total_final_no_unobs, mean, std)
    # Colour is always white in dSprites, so the column is constant (std 0).
    # Force it to 0; normalisation presumably left it NaN.
    df_total_final_no_unobs['color'] = 0

    # Train/test loaders: *_img hold images, *_tab the observed latent
    # variables, *_lab the outcome Y.
    loader_trainer_img, loader_trainer_tab, loader_trainer_lab, loader_test_img, loader_test_tab, loader_test_lab = \
        functions.get_loaders(df_total_final_no_unobs, imgs_df, FLAGS.batch_size)

    # Generate counterfactual samples (normalised with the same statistics).
    loader_count_images, loader_count_tabular = functions.counterfactual_simulations(
        FLAGS.data_points_count, FLAGS.count_samples, mean, std, imgs, df_total_final, latents_bases)

    # ---------------------------------------------------------------------------
    # Train, test and save results
    # ---------------------------------------------------------------------------
    cnet = network.NeuralNetworkImage()
    optimizer = optim.Adam(cnet.parameters(), lr=FLAGS.lr)
    loss_function = nn.MSELoss()
    cnet, loss_vals, accuracy_loss_train, hscic_train, loss_test_vals, accuracy_loss_test, hscic_test = \
        functions.train_model(loader_trainer_img, loader_trainer_tab, loader_trainer_lab, cnet, optimizer, FLAGS.beta,
                              FLAGS.epochs, loader_test_img, loader_test_tab, loader_test_lab, loss_function)

    # VCF: mean variance of the predicted outcomes across counterfactual samples.
    vcf = _counterfactual_vcf(cnet, loader_count_images, loader_count_tabular)

    results_dict = {str(FLAGS.beta): [accuracy_loss_test[-1], hscic_test[-1], vcf]}

    logging.info("Store results...")
    np.savez(os.path.join(out_dir, "results.npz"), **results_dict)

    logging.info("DONE")


if __name__ == "__main__":
    # absl parses the command-line flags into FLAGS, then invokes main(argv).
    app.run(main=main)
