import datetime as dt
import functools
import math
import os
import random
import sys
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import matplotlib
import tensorflow as tf
import torch
import numpy as np
import h5py

import tools
import pytorch_models
import path_config
import acas.writeNNet


# https://stackoverflow.com/questions/45667439/how-to-type-annotate-tensorflow-session-in-python3

# From: https://arxiv.org/pdf/1912.07084.pdf
# The networks have three inputs, one for each remaining
# state variable. Each network has five hidden layers of
# length 25 and nine outputs, one for each advisory. Nine
# smaller networks are trained instead of a single large
# network to reduce run-time required to evaluate each network.
# Each network was trained for 3000 epochs in 30 minutes using
# Tensorflow, resulting in nine neural networks that reduce
# required storage from 1.22GB to 103KB while maintaining the
# correct advisory 94.9% of the time.


def get_session(gpu_inds: List[int],
                gpu_mem_frac: float) -> tf.Session:
    """Build a TF1 ``tf.Session`` restricted to the requested GPUs, or to the CPU.

    Args:
        gpu_inds: GPU indices exposed through ``CUDA_VISIBLE_DEVICES``. Only the
            first entry decides between GPU and CPU: if it is -1 (or lower),
            every GPU is hidden and a default session is created.
        gpu_mem_frac: Value for ``per_process_gpu_memory_fraction``, a fixed
            cap on the share of GPU memory this process may reserve. Ignored
            on the CPU path.

    Returns:
        The newly created session.
    """
    wants_gpu = gpu_inds[0] > -1
    if not wants_gpu:
        # Hide all GPUs so TensorFlow falls back to the CPU.
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
        return tf.Session()

    # Number devices by PCI bus id rather than CUDA's default ordering.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join("%d" % ind for ind in gpu_inds)

    config = tf.ConfigProto(device_count={"GPU": len(gpu_inds)})
    config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
    return tf.Session(config=config)


def asymMSE(y_true, y_pred):
    """Asymmetric squared-error loss that favours keeping the best action on top.

    For each row the "best" column is ``argmax(y_true)``. Errors are scored as
    plain squares, except in two directions that get the heavier penalty
    ``40 * (e**2 + |e|)``:

    * the best column is under-predicted (``y_pred < y_true``); this term is
      additionally scaled by ``num_out - 1``;
    * any other column is over-predicted (``y_pred > y_true``).

    Returns the mean over all entries.

    Args:
        y_true: Target tensor. Assumed 2-D (rows, num_out) with a statically
            known column count (the TF1 ``Dimension.value`` is read).
        y_pred: Prediction tensor with the same number of columns.
    """
    num_out = y_true.shape[1].value
    assert y_pred.shape[1] == num_out

    loss_factor = 40.0

    err = y_true - y_pred
    best_mask = tf.one_hot(tf.argmax(y_true, axis=1), num_out)
    # 0 on the best column and -1 elsewhere. This flips the error's sign on the
    # non-best columns, so a positive value there means "predicted too high".
    other_mask = best_mask - 1

    err_best = err * best_mask
    err_other = err * other_mask

    heavy_best = loss_factor * (num_out - 1) * (tf.square(err_best) + tf.abs(err_best))
    heavy_other = loss_factor * (tf.square(err_other) + tf.abs(err_other))

    other_term = tf.where(err_other > 0, heavy_other, tf.square(err_other))
    best_term = tf.where(err_best > 0, heavy_best, tf.square(err_best))
    return tf.reduce_mean(other_term + best_term)


def custAcc(y_true, y_pred):
    """Return the fraction of rows whose predicted argmax equals the target argmax.

    The result is a float32 scalar tensor in [0, 1].
    """
    target_cls = tf.argmax(y_true, axis=1)
    predicted_cls = tf.argmax(y_pred, axis=1)
    hits = tf.cast(tf.equal(target_cls, predicted_cls), dtype="float32")
    return tf.reduce_mean(hits)


def _tf_dense_relu_net(x: tf.Tensor,
                       layer_sizes: List[int]) -> Tuple[tf.Tensor, Dict[str, tf.Variable], List[str], List[str]]:
    """Stack dense layers on ``x`` in the default graph: ReLU on hidden layers, linear output.

    Args:
        x: Input tensor of shape (None, layer_sizes[0]).
        layer_sizes: Widths of every layer, input first and output last.

    Returns:
        (y_out, variables keyed by name, weight variable names, bias variable names).
        Weights are named ``W<i>`` with shape (out, in); biases ``b<i>`` with shape (out, 1).
    """
    num_layers = len(layer_sizes)
    variables = {}
    weight_names = []
    bias_names = []

    z = x
    for i, (in_layer, out_layer) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        wis = "W" + str(i)
        bis = "b" + str(i)

        wi = tf.get_variable(wis, shape=[out_layer, in_layer])
        bi = tf.get_variable(bis, shape=[out_layer, 1])

        # Weights are stored (out, in) and applied as W @ z^T, so the
        # activations are transposed into and out of each layer.
        lin = tf.transpose(tf.matmul(wi, tf.transpose(z)) + bi)
        z = tf.nn.relu(lin) if i < num_layers - 2 else lin

        variables[wis] = wi
        variables[bis] = bi
        weight_names.append(wis)
        bias_names.append(bis)

    return z, variables, weight_names, bias_names


def fit_model_tensorflow(
    x_train: np.ndarray,
    q: np.ndarray,
    lr: float,
    epochs: int,
    batch_size: int,
    ident: str,
    hcas_root: str,
    hidden_layer_sizes: Optional[List[int]] = None,
    gpu: Optional[int] = None) -> Tuple[np.ndarray, List[np.ndarray], List[np.ndarray]]:
    """Fit a fully connected ReLU network to ``q`` with TensorFlow 1.x.

    The network has one input per column of ``x_train`` and one linear output
    per column of ``q``. Training uses Adam on shuffled mini-batches with a
    cross-entropy loss against ``argmax(q)``. After every epoch, accuracy and
    loss on a random sample of up to 1000 training rows are written to
    TensorBoard under ``<hcas_root>/tensorboard``.

    Args:
        x_train: Inputs, shape (n, num_inputs).
        q: Targets, shape (n, num_outputs); the row-wise argmax is the label.
        lr: Adam learning rate.
        epochs: Number of passes over the data.
        batch_size: Mini-batch size.
        ident: Prefix for the TensorBoard run name.
        hcas_root: Directory under which ``tensorboard/`` is created.
        hidden_layer_sizes: Widths of the hidden ReLU layers. When omitted,
            the module-level ``hidden_layer_sizes`` is used.
        gpu: GPU index passed to ``get_session`` (-1 means CPU). When omitted,
            the module-level ``gpu`` is used if defined, otherwise -1.

    Returns:
        (accuracies, weights, biases): the sampled accuracy per epoch, shape
        (epochs,); the weight matrices transposed to (in, out); and the
        flattened bias vectors.
    """
    module_globals = globals()
    if hidden_layer_sizes is None:
        # Backward compatibility: older callers relied on the module global.
        hidden_layer_sizes = module_globals["hidden_layer_sizes"]
    if gpu is None:
        gpu = module_globals.get("gpu", -1)

    test_size = 1000
    print_every = 200
    flush_secs = 1

    utcnowstr = dt.datetime.utcnow().strftime("%Y%m%d%H%M%S")
    tb_dir = os.path.join(hcas_root, "tensorboard")
    os.makedirs(tb_dir, exist_ok=True)
    tb_filename = "{}_{}".format(ident, utcnowstr)
    tb_fullfilename = os.path.join(tb_dir, tb_filename)
    print("Run 'tensorboard --logdir {}' (This run is called '{}')".format(tb_dir, tb_filename))

    n, num_inputs = x_train.shape
    n_q, num_outputs = q.shape
    assert n_q == n, "Size mismatch"

    # Use plain Python ints: np.concatenate would promote the sizes to float64
    # when hidden_layer_sizes is empty.
    layer_sizes = [int(num_inputs)] + [int(w) for w in hidden_layer_sizes] + [int(num_outputs)]

    tf.reset_default_graph()
    sess = get_session([gpu], 0.45)
    writer = None
    try:
        x = tf.placeholder(tf.float32, [None, num_inputs])
        y = tf.placeholder(tf.float32, [None, num_outputs])
        y_out, vd, weight_names, bias_names = _tf_dense_relu_net(x, layer_sizes)

        use_cross_entropy = True
        if use_cross_entropy:
            loss = cross_entropy_loss(y, y_out)
        else:
            loss = asymMSE(y, y_out)

        train_step = tf.train.AdamOptimizer(lr).minimize(loss)
        sess.run(tf.global_variables_initializer())

        writer = tf.summary.FileWriter(tb_fullfilename, flush_secs=flush_secs)
        writer.add_graph(sess.graph)

        cust_accuracy = custAcc(y, y_out)
        tf.summary.scalar("cust_accuracy", cust_accuracy)
        tf.summary.scalar("loss", loss)
        merged_summary = tf.summary.merge_all()

        num_rows = n
        num_batches = int(math.ceil(num_rows / batch_size))
        # Sampling without replacement cannot draw more rows than exist.
        eval_size = min(test_size, num_rows)
        accuracies = np.full((epochs,), np.nan)

        for e in range(epochs):
            train_indices = np.random.permutation(num_rows)
            for i in range(num_batches):
                idx = train_indices[i * batch_size:(i + 1) * batch_size]
                sess.run(train_step, feed_dict={x: x_train[idx, :], y: q[idx, :]})

            eval_inds = np.random.choice(num_rows, eval_size, replace=False)
            eval_feed = {x: x_train[eval_inds, :], y: q[eval_inds, :]}
            # A single run fetches the TensorBoard summary and the printed metrics.
            summary, c, m = sess.run([merged_summary, cust_accuracy, loss], feed_dict=eval_feed)
            writer.add_summary(summary, e + 1)
            accuracies[e] = c

            if e % print_every == 0:
                print("Epoch {:>4d} / {:>4d}: (accuracy, loss) = ({:.3f}, {:4.3f})".format(e, epochs, c, m))

        # Fetch trained parameters once, after training (also valid for epochs == 0).
        params = sess.run(vd)
    finally:
        if writer is not None:
            writer.close()
        sess.close()

    weights = [params[wn].T for wn in weight_names]
    biases = [params[bn].flatten() for bn in bias_names]
    return accuracies, weights, biases


def fit_model_pytorch(train_x: np.ndarray,
                      q: np.ndarray,
                      lr: float,
                      epochs: int,
                      hidden_layer_sizes: List[int],
                      batch_size: int,
                      criterion_name: str) -> Tuple[np.ndarray, List[np.ndarray], List[np.ndarray]]:
    """Train a ReLU classifier with PyTorch to predict ``argmax(q)`` from ``train_x``.

    Args:
        train_x: Inputs, shape (n, num_inputs).
        q: Targets, shape (n, num_outputs); the row-wise argmax is the class label.
        lr: Adam learning rate.
        epochs: Number of passes over the data.
        hidden_layer_sizes: Widths of the hidden layers passed to
            ``pytorch_models.build_relu_layers``.
        batch_size: Mini-batch size.
        criterion_name: ``"hinge"`` (``MultiMarginLoss``) or ``"cross_entropy"``
            (``CrossEntropyLoss``).

    Returns:
        (accuracies, weights, biases): the full-training-set accuracy per epoch,
        shape (epochs,); the weights of every ``torch.nn.Linear`` layer,
        transposed; and the flattened bias vectors.

    Raises:
        ValueError: If ``criterion_name`` is not one of the two supported names.
    """
    log_every_epoch = 10

    # Both names are compared explicitly, so an unknown criterion is rejected
    # instead of silently falling through to cross-entropy.
    if criterion_name == "hinge":
        criterion = torch.nn.MultiMarginLoss()
    elif criterion_name == "cross_entropy":
        criterion = torch.nn.CrossEntropyLoss()
    else:
        raise ValueError("Unknown criterion_name {!r}; expected 'hinge' or 'cross_entropy'".format(criterion_name))

    output_width = q.shape[1]
    train_y = np.argmax(q, axis=1)

    x_torch = torch.from_numpy(train_x).type(torch.FloatTensor)
    y_torch = torch.from_numpy(train_y).type(torch.LongTensor)

    train_dataset = torch.utils.data.TensorDataset(x_torch, y_torch)
    # Reshuffle every epoch so the mini-batches are not always taken in table order.
    dataloader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=batch_size,
                                             shuffle=True)

    include_bias = True
    layer_list = pytorch_models.build_relu_layers(train_x.shape[1],
                                                  hidden_layer_sizes,
                                                  output_width,
                                                  include_bias)
    model = pytorch_models.Net(layer_list)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    losses = np.full((epochs,), np.nan)
    accuracies = np.full((epochs,), np.nan)

    for epoch_idx in range(epochs):
        loss_sum = 0.0
        num_batches = 0
        for x, y in dataloader:
            loss = criterion(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            num_batches += 1
        # Mean mini-batch loss for the epoch; NaN when there were no rows.
        losses[epoch_idx] = loss_sum / num_batches if num_batches else np.nan

        # Accuracy on the full training set. no_grad avoids building an
        # autograd graph over every row just to take an argmax.
        with torch.no_grad():
            yhat = torch.argmax(model(x_torch), dim=1).numpy()
        accuracies[epoch_idx] = np.mean(yhat == train_y)

        if epoch_idx % log_every_epoch == 0:
            print("Epoch completed {:4d} / {:4d} -- loss, acc = {:.4f}, {:.4f}".format(epoch_idx, epochs, losses[epoch_idx], accuracies[epoch_idx]))

    linear_layers = [layer for layer in model.layers if isinstance(layer, torch.nn.Linear)]
    weights_and_biases = [_get_weight_and_bias_from_linearlike_layer(layer) for layer in linear_layers]

    weights = [w.T for w, _ in weights_and_biases]
    biases = [b.flatten() for _, b in weights_and_biases]
    return accuracies, weights, biases


def cross_entropy_loss(y: tf.Tensor, y_out: tf.Tensor) -> tf.Tensor:
    """Mean softmax cross-entropy of ``y_out`` logits against the argmax class of ``y``.

    ``y`` is reduced to integer labels with ``argmax`` over axis 1. The
    per-row sparse cross-entropy terms are then averaged into a scalar.
    """
    per_row = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.math.argmax(y, axis=1),
        logits=tf.convert_to_tensor(y_out))
    return tf.math.reduce_mean(per_row)


def _get_weight_and_bias_from_linearlike_layer(layer: Any) -> Tuple[np.ndarray, np.ndarray]:
    """Copy a layer's weight and bias out of PyTorch as NumPy arrays.

    The first entry of ``layer.parameters()`` is taken as the weight. If
    ``layer.bias`` is None, a zero column of shape (weight.shape[0], 1) stands
    in for the bias. Otherwise the second parameter is passed through
    ``tools.vec``.
    """
    params = list(layer.parameters())
    assert len(params) <= 2, "Expecting at most a weight and a bias"

    weight = params[0].detach().numpy()
    if layer.bias is not None:
        return weight, tools.vec(params[1].detach().numpy())
    return weight, np.zeros((weight.shape[0], 1))


def _load_training_subset(f_fullfilename: str,
                          min_scaled_psi: float,
                          max_scaled_psi: float) -> Tuple[np.ndarray, np.ndarray, List[Any], List[Any], List[Any], List[Any]]:
    """Read an HDF5 training file, keeping rows whose last input column is within bounds.

    Returns:
        (x_train, q, means, ranges, input_mins, input_maxes), read from the
        datasets "X", "y", "means", "ranges", "min_inputs" and "max_inputs".
        Only rows with min_scaled_psi <= X[:, -1] <= max_scaled_psi are kept.
    """
    # The context manager closes the HDF5 handle even if a read fails.
    with h5py.File(f_fullfilename, "r") as f:
        x_train = np.array(f["X"])
        q = np.array(f["y"])
        means = list(f["means"])
        ranges = list(f["ranges"])
        input_mins = list(f["min_inputs"])
        input_maxes = list(f["max_inputs"])

    last_col = x_train[:, -1]
    subset_rows = np.logical_and(last_col >= min_scaled_psi, last_col <= max_scaled_psi)
    return x_train[subset_rows, :], q[subset_rows, :], means, ranges, input_mins, input_maxes


def _do_training(ident: str,
                 tau: int,
                 pra: int,
                 lr: float,
                 epochs: int,
                 batch_size: int,
                 criterion_name: str,
                 min_scaled_psi: float,
                 max_scaled_psi: float,
                 do_pytorch_fitting: bool) -> np.ndarray:
    """Train one network for a (tau, pra) table slice and write it as a .nnet file.

    Loads ``<trainingdir>/<ident>_pra{pra}_tau{tau:02}.h5`` and filters rows by
    the scaled-psi bounds. It then fits with PyTorch or TensorFlow and writes
    the weights plus the file's normalization metadata to ``nnet_dir``.

    Relies on the module globals ``trainingdir``, ``nnet_dir``, ``hcas_root``,
    ``num_relu_layers``, ``neurons_per_layer``, ``psi_deg`` and
    ``hidden_layer_sizes`` (assigned in the ``__main__`` block).

    Returns:
        Per-epoch accuracies reported by the chosen fitting routine.
    """
    training_data_filename = (ident + "_pra{:d}_tau{:02}.h5").format(pra, tau)
    f_fullfilename = os.path.join(trainingdir, training_data_filename)

    x_train, q, means, ranges, input_mins, input_maxes = _load_training_subset(
        f_fullfilename, min_scaled_psi, max_scaled_psi)
    print("Fitting DNN with {} rows".format(x_train.shape[0]))

    # NOTE(review): the filename is tagged with the global psi_deg, which is
    # not derived from min/max_scaled_psi; confirm the two are kept in sync.
    ident_pattern = ident + "_pra{:d}_tau{:02d}_relulayers{:03d}_neurons{:03d}_psi{:+03d}"
    filename = ident_pattern.format(pra, tau, num_relu_layers, neurons_per_layer, psi_deg)

    if do_pytorch_fitting:
        accuracies, weights, biases = fit_model_pytorch(x_train,
                                                        q,
                                                        lr,
                                                        epochs,
                                                        hidden_layer_sizes,
                                                        batch_size,
                                                        criterion_name)
        filename += "_pytorch"
    else:
        accuracies, weights, biases = fit_model_tensorflow(x_train,
                                                           q,
                                                           lr,
                                                           epochs,
                                                           batch_size,
                                                           ident,
                                                           hcas_root)

    nnet_fullfilename = os.path.join(nnet_dir, filename + ".nnet")
    print("Writing {}".format(nnet_fullfilename))
    acas.writeNNet.writeNNet(weights,
                             biases,
                             input_mins,
                             input_maxes,
                             means,
                             ranges,
                             nnet_fullfilename)
    return accuracies


if __name__ == "__main__":
    # The functions above read gpu, hidden_layer_sizes, hcas_root, trainingdir,
    # nnet_dir, num_relu_layers, neurons_per_layer and psi_deg as module
    # globals, so those names must be assigned here.
    gpu = -1  # -1 selects the CPU in get_session
    seed = 10011
    for seed_fn in (random.seed, torch.manual_seed, np.random.seed, tf.set_random_seed):
        seed_fn(seed)

    paths = path_config.get_paths()
    hcas_root = paths['acas']

    tabledir = os.path.join(hcas_root, "GenerateTable")
    trainingdir = os.path.join(hcas_root, "TrainingData")
    nnet_dir = os.path.join(hcas_root, "networks")
    os.makedirs(nnet_dir, exist_ok=True)

    # Architecture. Other settings tried: num_relu_layers in {2, 3};
    # neurons_per_layer in {8, 12, 18, 20, 24, 25}.
    num_relu_layers = 1
    neurons_per_layer = 16
    hidden_layer_sizes = [neurons_per_layer] * num_relu_layers

    do_pytorch_fitting = False

    # Other tolerances tried: .10, .125, .15, 1.0.
    subset_tolerance = .25
    is_medium = subset_tolerance <= .10
    is_medium_large = .10 < subset_tolerance <= .25

    # Other headings tried: -135, +225, +45, -45.
    psi_deg = +90

    # Hyperparameters by (tolerance band, depth, width).
    criterion_name = "hinge"
    single_layer = num_relu_layers == 1
    if is_medium and single_layer:
        batch_size, lr, epochs = 2 ** 7, 1e-3, 3000
    elif is_medium and num_relu_layers == 2:
        batch_size, lr, epochs = 2 ** 6, 4e-3, 3000
    elif is_medium_large and single_layer and neurons_per_layer <= 20:
        batch_size, lr, epochs = 2 ** 8, 4 * 1e-3, 3000
    elif is_medium_large and single_layer and neurons_per_layer > 20:
        batch_size, lr, epochs = 2 ** 9, 4 * 1e-3, 3000
        criterion_name = "cross_entropy"
    else:
        batch_size, lr, epochs = 2 ** 9, 4e-3, 1000

    taus = [0]  # full sweep: [0, 5, 10, 15, 20, 30, 40, 60]
    pras = [0]  # full sweep: list(range(5))
    num_taus = len(taus)
    num_pras = len(pras)

    ident = "baseline"
    min_scaled_psi = 0.0
    max_scaled_psi = .50

    accuracies_dict = {}
    for tau in taus:
        for pra in pras:
            print("(tau, pra) = ({}, {})".format(tau, pra))
            accuracies_dict[(tau, pra)] = _do_training(ident,
                                                       tau,
                                                       pra,
                                                       lr,
                                                       epochs,
                                                       batch_size,
                                                       criterion_name,
                                                       min_scaled_psi,
                                                       max_scaled_psi,
                                                       do_pytorch_fitting)

    # Grids indexed [tau, pra]: the mean of the last 100 epochs and the best epoch.
    terminal_accuracies = np.array([[np.mean(accuracies_dict[(tau, pra)][-100:]) for pra in pras]
                                    for tau in taus])
    max_accuracies = np.array([[np.max(accuracies_dict[(tau, pra)]) for pra in pras]
                               for tau in taus])

    plt.imshow(max_accuracies)

    for grid in (max_accuracies, terminal_accuracies):
        for stat in (np.min, np.mean, np.median):
            print("{:.4f}".format(stat(grid)))
