from tensorflow.keras.losses import BinaryCrossentropy, MeanSquaredError
from datetime import datetime
from pathlib  import Path
import argparse
import numpy as np
from src.dataset import *
from src.models.gan import *
from src.models.other import *

def print_preamble(exp_name, ds_name, nranks, batch_size, epochs, hyperparams_dict):
    """Print a boxed banner describing an experiment run to stdout.

    The first row lists ``exp_name, ds_name, nranks, batch_size, epochs``
    (comma separated). Below it come the start timestamp and one
    ``key:    value`` row per entry of ``hyperparams_dict``, with keys
    left-aligned to a common width.

    The box is sized to its widest row, so a long timestamp or
    hyperparameter row never overflows the border. An empty
    ``hyperparams_dict`` is accepted.

    Args:
        exp_name, ds_name, nranks, batch_size, epochs: values shown on the
            title row.
        hyperparams_dict: mapping of hyperparameter name -> value; may be empty.
    """
    title = "{}, {}, {}, {}, {}".format(exp_name, ds_name, nranks, batch_size, epochs)
    rows = ["Start:    {}".format(datetime.now().strftime("%b %d, %Y - %H:%M:%S"))]
    # default=0 lets an empty hyperparameter dict through instead of raising.
    key_width = max((len(str(key)) for key in hyperparams_dict), default=0)
    rows.extend("{}:    {}".format(str(key).ljust(key_width), val)
                for key, val in hyperparams_dict.items())

    # Size the box to its widest row so every "|" border lines up.
    width = max(len(line) for line in [title] + rows)
    ruler = "+-{}-+".format("-" * width)
    lines = [ruler, "| {} |".format(title.ljust(width)), ruler]
    lines.extend("| {} |".format(row.ljust(width)) for row in rows)
    lines.append(ruler)
    # One write for the whole banner. Flush so it shows up promptly even when
    # stdout is buffered (e.g. redirected to a log file).
    print("\n".join(lines), flush=True)

def make_experiment_paths(exp_dir, exp_name, ds_name, nranks, exp_id, hyperparams_dict):
    """Create a timestamped run directory and return its path as a string.

    Layout::

        <exp_dir>/<exp_name>/<ds_name>/nranks-<nranks>/<k1-v1_k2-v2...>/<timestamp>

    where the hyperparameter segment joins ``key-value`` pairs with ``_`` in
    the dict's iteration order. The timestamp uses ``%Y-%b-%d_%H:%M:%S``.
    Missing parent directories are created, and an existing directory is
    reused.

    Note: ``exp_id`` is accepted but not used in the path.
    """
    stamp = datetime.now().strftime("%Y-%b-%d_%H:%M:%S")
    pieces = []
    for name, value in hyperparams_dict.items():
        pieces.append("{}-{}".format(name, value))
    hp_segment = "_".join(pieces)

    run_dir = "{}/{}/{}/nranks-{}/{}/{}".format(
        exp_dir, exp_name, ds_name, nranks, hp_segment, stamp)
    Path(run_dir).mkdir(parents=True, exist_ok=True)
    return run_dir

def set_count_arguments(argv=None):
    """Parse command-line options for a counting experiment.

    Args:
        argv: Optional list of argument strings to parse. When ``None``,
            argparse reads ``sys.argv[1:]``, which matches the original
            behavior.

    Returns:
        argparse.Namespace with the parsed options.

    ``--batch_size``, ``--epochs`` and ``--exp_id`` are optional and fall back
    to their defaults. argparse never applies ``default`` to a
    ``required=True`` option, so these three are deliberately not required.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',         type=str, required=True, help='The name of the dataset')
    parser.add_argument('--dataf',           type=str, required=True, help='the name of the train set file')
    parser.add_argument('--data_dir',        type=str, required=True, help='the path to the data')
    parser.add_argument('--batch_size',      type=int, default=32,  help='default: %(default)s')
    parser.add_argument('--epochs',          type=int, default=200, help='default: %(default)s')
    parser.add_argument('--exp_id',          type=int, default=0,   help='default: %(default)s')
    parser.add_argument('--exp_name',        type=str, required=True)
    return parser.parse_args(argv)

def set_default_arguments(argv=None):
    """Parse command-line options for a ranking / density-map GAN experiment.

    Args:
        argv: Optional list of argument strings to parse. When ``None``,
            argparse reads ``sys.argv[1:]``, which matches the original
            behavior.

    Returns:
        argparse.Namespace with the parsed options. ``extra_dataf`` and
        ``extra_data_dir`` are ``None`` when not given.

    ``--batch_size``, ``--epochs`` and ``--exp_id`` are optional and fall back
    to their defaults. argparse never applies ``default`` to a
    ``required=True`` option, so these three are deliberately not required.
    Most integer switches (e.g. ``--equivariant``, ``--no_dmap``) are checked
    with ``== 1`` by the helpers in this module, so any other value leaves the
    switch off.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--Nranks',          type=int, required=True, help='The # of ranking examples to use')
    parser.add_argument('--dataset',         type=str, required=True, help='The name of the dataset')
    parser.add_argument('--trainf',          type=str, required=True, help='the name of the train set file')
    parser.add_argument('--evalf',           type=str, required=True, help='the name of the eval set file')
    parser.add_argument('--data_dir',        type=str, required=True, help='the path to the data')
    parser.add_argument('--batch_size',      type=int, default=32,  help='default: %(default)s')
    parser.add_argument('--epochs',          type=int, default=200, help='default: %(default)s')
    parser.add_argument('--exp_id',          type=int, default=0,   help='default: %(default)s')
    parser.add_argument('--exp_name',        type=str, required=True)
    parser.add_argument('--pairwise',        type=int, required=True)
    parser.add_argument('--equivariant',     type=int, required=True)
    parser.add_argument('--prototype',       type=int, required=True)
    parser.add_argument('--spatial_dropout', type=int, required=True)
    parser.add_argument('--mae_dmap',        type=int, required=True)
    parser.add_argument('--ood_norm',        type=int, required=True)
    parser.add_argument('--no_rank',         type=int, required=True)
    parser.add_argument('--biased_dmap',     type=int, required=True)
    parser.add_argument('--no_dmap',         type=int, required=True)
    parser.add_argument('--extra_dataf',     type=str)
    parser.add_argument('--extra_data_dir',  type=str)
    return parser.parse_args(argv)

def get_simple_hyperparams():
    """Randomly sample the base GAN hyperparameters from NumPy's global RNG.

    Returns a dict with keys, in this order:
      * "lr-G", "lr-D": one shared value, drawn uniformly between sqrt(1e-5)
        and sqrt(2e-3), squared, then rounded to 5 decimals;
      * "beta1": drawn from Normal(0.5, 0.1), rounded to 3 decimals;
      * "gan-loss": "gan" or "lsgan", chosen uniformly.
    """
    sqrt_low = np.sqrt(1e-05)
    sqrt_high = np.sqrt(2e-03)
    # Sampling in sqrt-space and squaring skews draws toward smaller rates.
    shared_lr = np.round(np.square(np.random.uniform(low=sqrt_low, high=sqrt_high)), 5)

    sampled = dict.fromkeys(("lr-G", "lr-D"), shared_lr)
    sampled["beta1"] = np.round(np.random.normal(loc=0.5, scale=0.1), 3)
    sampled["gan-loss"] = np.random.choice(["gan", "lsgan"])
    return sampled

def get_hyperparams(args):
    """Randomly sample GAN hyperparameters plus any optional terms enabled in ``args``.

    The base entries match ``get_simple_hyperparams`` ("lr-G"/"lr-D" share one
    sqrt-space-uniform draw rounded to 5 decimals, "beta1" ~ Normal(0.5, 0.1)
    rounded to 3 decimals, "gan-loss" in {"gan", "lsgan"}). Then each optional
    term whose flag on ``args`` equals exactly 1 gets a value from its
    candidate list, in this order:

        args.equivariant     -> "equivariant"
        args.prototype       -> "prototype"
        args.spatial_dropout -> "spatial_dropout" (always 0.1)
        args.mae_dmap        -> "mae"
        args.ood_norm        -> "ood"

    A single-candidate list is used directly, without consuming randomness.
    All draws come from NumPy's global RNG.
    """
    sqrt_low, sqrt_high = np.sqrt(1e-05), np.sqrt(2e-03)
    shared_lr = np.round(np.square(np.random.uniform(low=sqrt_low, high=sqrt_high)), 5)
    beta1 = np.round(np.random.normal(loc=0.5, scale=0.1), 3)
    gan_loss = np.random.choice(["gan", "lsgan"])
    sampled = {"lr-G": shared_lr, "lr-D": shared_lr, "beta1": beta1, "gan-loss": gan_loss}

    optional_terms = (
        (args.equivariant,     "equivariant",     [0.2, 0.5, 1, 1.5]),
        (args.prototype,       "prototype",       [1., 2., 5., 10., 20.]),
        (args.spatial_dropout, "spatial_dropout", [0.1]),
        (args.mae_dmap,        "mae",             [0.25, 0.5, 1.0, 1.5]),
        (args.ood_norm,        "ood",             [0.1, 0.3, 0.5, 1.0, 1.5]),
    )
    for flag, key, candidates in optional_terms:
        if flag != 1:
            continue
        sampled[key] = candidates[0] if len(candidates) == 1 else np.random.choice(candidates)
    return sampled

def get_loss_dict(hyperparams):
    """Build the loss objects and loss weights for the generator "G" and discriminator "D".

    Args:
        hyperparams: mapping that must contain "gan-loss" ("gan" or "lsgan").
            It may also contain "ood", which is used as the weight of the
            generator's "ood" term.

    Returns:
        (loss_dict, loss_w_dict):
          * loss_dict["G"] holds "advr", "rank" and "count" losses.
            loss_dict["D"] holds "advr" and "rank". The adversarial loss is
            MeanSquaredError for "lsgan" and BinaryCrossentropy(from_logits=True)
            for "gan". The rank loss is BinaryCrossentropy(from_logits=True).
            G and D share the same "advr" and "rank" loss instances.
          * loss_w_dict gives weight 1 to every "advr"/"rank" term. It also sets
            loss_w_dict["G"]["ood"] = hyperparams["ood"] when that key is present.

    Raises:
        KeyError: if "gan-loss" is missing.
        ValueError: if hyperparams["gan-loss"] is neither "gan" nor "lsgan".
    """
    gan_loss = hyperparams["gan-loss"]
    # Validate before constructing any loss objects.
    if gan_loss not in ("gan", "lsgan"):
        raise ValueError("Received unknown loss-type {}".format(gan_loss))

    rank_loss = BinaryCrossentropy(from_logits=True)
    if gan_loss == "lsgan":
        adv_loss = MeanSquaredError()
    else:
        adv_loss = BinaryCrossentropy(from_logits=True)

    loss_dict = {"G": {"advr": adv_loss,
                       "rank": rank_loss,
                       "count": MeanSquaredError()},

                 "D": {"advr": adv_loss,
                       "rank": rank_loss}}

    loss_w_dict = {"G": {"advr": 1,
                         "rank": 1},

                   "D": {"advr": 1,
                         "rank": 1}}

    if "ood" in hyperparams:
        loss_w_dict["G"]["ood"] = hyperparams["ood"]

    return loss_dict, loss_w_dict


def get_ds_and_model(args):
    """Choose the dataset-iterator class and GAN model class from integer flags on ``args``.

    Reads ``args.no_rank``, ``args.biased_dmap`` and ``args.no_dmap``. A flag
    counts as set only when it equals 1.

    Model precedence:    no_dmap -> NoDmapRank, else no_rank -> DMapNoRank,
                         else DMapGAN.
    Iterator precedence: biased_dmap -> BiasedDmapDSIterator,
                         else no_dmap -> RankDSIterator, else DMapDSIterator.

    Returns:
        (DataIterator, GANModel, ExtraDataIterator). The third element is
        always None here.
    """
    no_rank, biased_dmap, no_dmap = args.no_rank, args.biased_dmap, args.no_dmap

    model_cls = (NoDmapRank if no_dmap == 1
                 else DMapNoRank if no_rank == 1
                 else DMapGAN)
    # NOTE(review): biased_dmap=1 together with no_dmap=1 pairs
    # BiasedDmapDSIterator with NoDmapRank. Confirm this combination is intended.
    iterator_cls = (BiasedDmapDSIterator if biased_dmap == 1
                    else RankDSIterator if no_dmap == 1
                    else DMapDSIterator)

    return iterator_cls, model_cls, None

def get_count_ds_and_model(args):
    """Return the fixed (DataIterator, GANModel, ExtraDataIterator) triple for counting runs.

    ``args`` is ignored. It is presumably accepted so callers can use this
    interchangeably with ``get_ds_and_model``.
    """
    extra_iterator = None
    return DMapCountDSIterator, DMapCountGAN, extra_iterator
