import json
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, MeanSquaredError
from src.experiment.utils import TimedExperimentManager
import numpy as np
from sklearn.metrics import r2_score

""":::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::"""

def default_gan_runner(trainf, data_dir, DataIterator, batch_size, nranks, epochs, path, GAN, loss_dict, loss_w_dict, hyperparams, extra_data):
    """Build, compile and train a GAN, then write the training history to disk.

    Args:
        trainf: Path to a JSON file whose top-level value is sliceable
            (e.g. a list); only its first ``nranks`` entries are used.
        data_dir: Passed as the second argument to ``DataIterator``.
        DataIterator: Callable invoked as ``DataIterator(entries, data_dir)``;
            the result must expose ``build_dataset(batch_size, drop_remainder=True)``.
        batch_size: Batch size forwarded to ``build_dataset``.
        nranks: Number of leading entries of the training JSON to keep.
        epochs: Forwarded to ``gan.fit``.
        path: Output directory prefix; checkpoints are registered under
            ``path + "/checkpoints"`` and the history goes to ``path + "/train.log"``.
        GAN: Callable invoked as ``GAN("imagenet", manager, **hyperparams)``.
        loss_dict: Forwarded to ``gan.compile``.
        loss_w_dict: Forwarded to ``gan.compile``.
        hyperparams: Mapping that must contain ``"lr-G"``, ``"lr-D"`` and
            ``"beta1"``; the whole mapping is also forwarded to ``GAN`` as
            keyword arguments.
        extra_data: Indexable of three items
            ``(iterator_callable, json_path, data_dir)``. When the first item
            is ``None`` the GAN is fitted on the training dataset only;
            otherwise an extra dataset is built from the *entire* JSON file
            (it is not sliced by ``nranks``) and passed to ``gan.fit``.

    Returns:
        The ``Genr`` attribute of the trained GAN object.
    """
    # Use context managers so file handles are closed deterministically
    # instead of being left for the garbage collector.
    with open(trainf) as train_file:
        train_entries = json.load(train_file)[0:nranks]
    train_ds = DataIterator(train_entries, data_dir).build_dataset(batch_size, drop_remainder=True)

    opt_dict = {"G": Adam(hyperparams["lr-G"], beta_1=hyperparams["beta1"]),
                "D": Adam(hyperparams["lr-D"], beta_1=hyperparams["beta1"])}

    # NOTE(review): 2.5*60*60 == 9000; presumably a 2.5 hour budget expressed
    # in seconds -- confirm against TimedExperimentManager.
    gan = GAN("imagenet", TimedExperimentManager(2.5*60*60), **hyperparams)
    gan.compile(opt_dict, loss_dict, loss_w_dict)
    # NOTE(review): the meaning of the second argument (1) is defined by
    # register_checkpoint, which is not visible here.
    gan.register_checkpoint(path + "/checkpoints", 1)

    # Identity comparison: ``== None`` can be hijacked by a custom __eq__.
    if extra_data[0] is None:
        history = gan.fit(train_ds, epochs)
    else:
        extraDsIter = extra_data[0]
        extradf = extra_data[1]
        extra_data_dir = extra_data[2]
        with open(extradf) as extra_file:
            extra_entries = json.load(extra_file)
        extra_ds = extraDsIter(extra_entries, extra_data_dir).build_dataset(batch_size, drop_remainder=True)
        history = gan.fit(train_ds, epochs, extra_ds)

    # ``history`` must be JSON-serialisable; closing the file guarantees the
    # log is flushed before the function returns.
    with open(path + "/train.log", "w") as log_file:
        json.dump(history, log_file)
    return gan.Genr


""":::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::"""

def get_generator_results(ds_iter, generator):
    """Score a generator's per-sample count predictions over a dataset.

    Each element of ``ds_iter`` is unpacked as ``(x, y)``. The output of
    ``generator.predict(x)`` is summed over axes 1, 2 and 3 to give one
    predicted count per sample, which is compared with ``y.numpy()``.

    Args:
        ds_iter: Iterable yielding ``(x, y)`` pairs where ``y`` exposes
            ``.numpy()``.
        generator: Object exposing ``predict(x)`` that returns an array with
            at least four axes (presumably a batch of density maps -- confirm
            against the model definition).

    Returns:
        A ``(mae, r2)`` tuple of Python floats: the mean absolute error
        between true and predicted counts, and their R^2 score.
    """
    true_counts = []
    predicted_counts = []
    errors = []

    for inputs, targets in ds_iter:
        # Collapse every non-batch axis so each sample yields a single count.
        batch_pred = np.sum(generator.predict(inputs), axis=(1, 2, 3))
        batch_true = targets.numpy()

        errors.extend(np.abs(batch_true - batch_pred).tolist())
        true_counts.extend(batch_true.tolist())
        predicted_counts.extend(batch_pred.tolist())

    mae = float(np.mean(errors))
    r_squared = float(r2_score(y_true=true_counts, y_pred=predicted_counts))
    return mae, r_squared

""":::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::"""

def evaluation_runner(evalf, data_dir, DataIterator, generator, save_path):
    """Evaluate a generator on the test and validation splits and save the scores.

    Args:
        evalf: Path to a JSON file containing an object with ``"test"`` and
            ``"val"`` keys; each value is handed to ``DataIterator``.
        data_dir: Passed as the second argument to ``DataIterator``.
        DataIterator: Callable invoked as ``DataIterator(entries, data_dir)``;
            the result must expose ``build_dataset(batch_size)``.
        generator: Model forwarded to ``get_generator_results``.
        save_path: Destination file; it is overwritten with a JSON object of
            the form ``{"test": {"mae": ..., "r^2": ...}, "val": {...}}``.

    Returns:
        None. The results are only written to ``save_path``.
    """
    # Context managers ensure the input handle is released and the output is
    # flushed and closed before returning.
    with open(evalf) as eval_file:
        evald = json.load(eval_file)

    # Both splits are evaluated with a fixed batch size of 32.
    test_ds = DataIterator(evald["test"], data_dir).build_dataset(32)
    val_ds = DataIterator(evald["val"], data_dir).build_dataset(32)

    test_mae, test_rsqr = get_generator_results(test_ds, generator)
    val_mae, val_rsqr = get_generator_results(val_ds, generator)
    result = {"test": {"mae": test_mae, "r^2": test_rsqr},
              "val" : {"mae": val_mae,  "r^2": val_rsqr}}

    with open(save_path, "w") as out_file:
        json.dump(result, out_file)
