"""
MIT License

Copyright (c) 2022 Author(s)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import itertools, io
import os, yaml, datetime

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import tensorflow as tf 

def load_yaml(yaml_path):
    """Load a YAML file with PyYAML's SafeLoader.

    Args:
        yaml_path: A str. Path to the YAML file.
    Returns:
        config: The parsed document (typically a dict for config files).
    Raises:
        AssertionError: If `yaml_path` does not exist.
    """
    assert os.path.exists(yaml_path), "Yaml path does not exist: " + yaml_path
    # Config files are expected to be UTF-8; do not depend on the platform's
    # default text encoding (e.g. cp1252 / cp932 on Windows).
    with open(yaml_path, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)
    return config


def concat_configs(paths):
    """
    Load several YAML files and merge their top-level keys into one dict.

    paths: A list (or tuple) of paths to yaml files. A single path is
        accepted too and yields (a copy of) that file's content.
    Returns:
        config: A dict holding every top-level key of every file. An empty
            YAML file (which loads as None) contributes nothing.
    Raises:
        AssertionError: If `paths` is a bare str or is empty, or if a
            top-level key appears in more than one file.
    """
    # A bare str would otherwise be iterated character by character.
    assert not isinstance(paths, str), "paths must be a list of paths, got a str."
    assert len(paths) > 0, "paths must not be empty."

    config = {}
    for iter_path in paths:
        add = load_yaml(iter_path)
        if add is None:  # empty YAML file: nothing to merge
            continue
        for k in add:
            assert k not in config, "Key {} duplicated.".format(k)
        config.update(add)  # concat
    return config


def set_gpu_devices(gpu):
    """ Expose a single physical GPU to TensorFlow and enable memory growth on it.

    Args:
        gpu: An int. Index into the list of physical GPUs reported by
            tf.config.experimental.list_physical_devices('GPU')
            (negative indices count from the end, as with Python lists).
    Raises:
        AssertionError: If no GPU is available.
        IndexError: If `gpu` is out of range for the available GPUs.
    NOTE(review): TensorFlow only accepts these settings before its GPUs are
    initialized, so this presumably must run early -- confirm at call sites.
    """
    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    num_gpus = len(physical_devices)
    assert num_gpus > 0, "Not enough GPU hardware devices available"
    # Validate the requested index explicitly so the error names the problem
    # instead of surfacing as a bare list IndexError.
    if not -num_gpus <= gpu < num_gpus:
        raise IndexError(
            "GPU index {} out of range: {} GPU(s) available.".format(gpu, num_gpus))
    device = physical_devices[gpu]
    tf.config.experimental.set_visible_devices(device, 'GPU')
    tf.config.experimental.set_memory_growth(device, True)


def fix_random_seed(flag_seed, seed=None):
    """ Optionally seed NumPy's and TensorFlow's global random generators.

    Args:
        flag_seed: A bool. If truthy, both generators are seeded with `seed`
            and a confirmation is printed; otherwise only a notice is printed.
        seed: Forwarded unchanged to np.random.seed and tf.random.set_seed.
            NOTE(review): with flag_seed=True and seed=None the message still
            says "fixed" -- confirm callers always pass an int here.
    """
    if not flag_seed:
        print("Random seed not fixed.")
        return

    np.random.seed(seed)
    tf.random.set_seed(seed)
    print("Numpy and TensorFlow's random seeds fixed: seed=" + str(seed))


def config_checker(config):
    """ Validate an experiment config dict; raise AssertionError on failure.

    Args:
        config: A dict. "exp_phase" must be "try", "tuning" or "stat";
            "weights_base" and "weights_top" must each be None or "imagenet";
            in the "stat" phase "pruner_index" must also be present and 0.
    """
    phase = config["exp_phase"]
    assert phase in ["try", "tuning", "stat"]

    allowed_weights = [None, "imagenet"]
    for weight_key in ("weights_base", "weights_top"):
        assert config[weight_key] in allowed_weights

    # The "stat" phase is run without pruning (pruner_index 0).
    if phase == "stat":
        assert config["pruner_index"] == 0


def config_checker_rnn(config):
    """ Validate an RNN experiment config dict; raise AssertionError on failure.

    Args:
        config: A dict. "exp_phase" must be "try", "tuning" or "stat";
            in the "stat" phase "pruner_index" must also be present and 0.
    """
    phase = config["exp_phase"]
    assert phase in ("try", "tuning", "stat")

    if phase != "stat":
        return
    # The "stat" phase is run without pruning (pruner_index 0).
    assert config["pruner_index"] == 0


def save_config_as_yaml(config, root_configs, subproject_name, exp_phase, comment, time_stamp):
    """
    Dump `config` to a timestamped YAML file in the experiment's config directory.

    - Remark
    The file is written to
    'root_configs'/'subproject_name'_'exp_phase'/'comment'_'time_stamp'/config_savedYYYYmmddHHMMSS.yaml
    where the trailing timestamp is the local time of this call (second
    resolution). Missing directories are created.
    NOTE: two calls within the same second target the same file, and the
    later call overwrites the earlier one.

    - Returns
    path_config: A str. Path of the written YAML file.
    """
    dir_config = "{}/{}_{}/{}_{}".format(
        root_configs, subproject_name, exp_phase, comment, time_stamp)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(dir_config, exist_ok=True)

    tmp_now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    path_config = dir_config + "/config_saved{}.yaml".format(tmp_now)
    with open(path_config, "w", encoding="utf-8") as f:
        yaml.dump(config, f)
    return path_config


def path2folders(path):
    """
    Split a path into the list of its non-empty components.

    - Args
    path: A str. E.g., "/home/x/y/z" and "/raid/a/b/c.txt".
        Any number of trailing slashes is ignored.
    - Returns
    folders: A list. E.g., ["home", "x", "y", "z"] and ["raid", "a", "b", "c.txt"].
        The root itself is not included, so "/" and "" both give [].
    """
    # Drop *all* trailing separators: with even one left, os.path.split
    # returns an empty tail on the first call and nothing would be collected.
    path = path.rstrip("/" + os.sep)

    folders = []
    while True:
        # Every non-empty tail strictly shortens `path`, so this loop
        # terminates for paths of any depth; no iteration cap is needed.
        path, folder = os.path.split(path)
        if not folder:
            break
        folders.append(folder)

    folders.reverse()
    return folders


def filename_rewriter_if_exists(filepath):
    """ Return a path that does not collide with an existing file.

    If `filepath` does not exist it is returned unchanged. Otherwise "_<n>"
    is inserted before the extension, using the smallest n = 1, 2, ... whose
    path does not exist yet, and both paths are printed.
    - Caution
    Only when the file exists: names containing more than one "."
    (e.g., "a.tar.gz") are rejected, since the extension would be ambiguous.
    The existence check is not atomic.
    - Args
    filepath: A str. E.g., "/home/x/y/z.txt", "/data/x/y/z", etc.
    - Returns
    filepath: A str.
      E.g., "/home/x/y/z_1.txt", "/data/x/y/z_1", etc.
      if there already exists the file. Else
      "/home/x/y/z.txt", "/data/x/y/z", etc.
    - Raises
    ValueError: If the file exists and its name has more than one ".".
    """
    if os.path.exists(filepath):
        name = os.path.basename(filepath)
        if name.count(".") > 1:
            raise ValueError(
                "Double extensions (e.g., .tar.gz) not supported."
                + "\nGot filepath = {} .".format(filepath))

        original = filepath
        stem, ext = os.path.splitext(filepath)
        suffix = 0
        while os.path.exists(filepath):
            suffix += 1
            filepath = "{}_{}{}".format(stem, suffix, ext)

        print("File path exists.")
        print("File path has been overwritten\nfrom {}\nto {}".format(original, filepath))

    return filepath


def permute_in_subindex_order(list_glob):
    """ Order glob results so that "_<n>"-suffixed files follow numeric order.

    - Args
    list_glob: A list of str paths, e.g. ["d/a_10.txt", "d/a.txt", "d/a_2.txt"].
        After a lexicographic sort the first path is kept in front as is;
        every other path must end (before its extension) with "_<int>".
    - Returns
    A new list: the lexicographically first path, followed by the remaining
    paths sorted by their integer subindex (ties keep lexicographic order),
    e.g. ["d/a.txt", "d/a_2.txt", "d/a_10.txt"]. An empty input gives [].
    - Raises
    AssertionError: If list_glob is not a list.
    ValueError: If a non-first path has no integer after its last "_".
    """
    assert isinstance(list_glob, list)
    list_glob = sorted(list_glob)
    # Also covers the empty list, which previously raised IndexError.
    if len(list_glob) <= 1:
        return list_glob

    def _subindex(path):
        # "d/a_10.txt" -> 10
        return int(os.path.splitext(path)[0].split("_")[-1])

    head, rest = list_glob[0], list_glob[1:]
    # sorted() is stable, unlike np.argsort's default quicksort.
    return [head] + sorted(rest, key=_subindex)


def restrict_classes(llrs, labels, list_classes):
    """ Keep only the samples whose label is listed in `list_classes`.

    Args:
        llrs: A Tensor with shape (batch, ...).
            E.g., (batch, duration, num classes, num classes).
        labels: A Tensor with shape (batch, ).
            NOTE(review): rank-1 labels are assumed; for higher-rank labels
            tf.where yields multi-dimensional indices that the flattening
            below would mix up -- confirm callers only pass (batch,).
        list_classes: A list of integers specifying the classes
            to be extracted. E.g. list_classes = [0,2,9] for NMNIST.
            Duplicated entries are ignored. An empty list disables filtering.
    Returns:
        llrs_rest: A Tensor with shape (<= batch, *llrs.shape[1:]), samples
            kept in their original batch order.
            If no sample of the listed classes is found, llrs_rest = None.
        lbls_rest: A Tensor with shape (<= batch, *labels.shape[1:]).
            If no sample of the listed classes is found, lbls_rest = None.
        If list_classes == [], (llrs, labels) is returned unchanged.
    """
    if list_classes == []:
        return llrs, labels

    # dict.fromkeys drops duplicate classes (keeping first-seen order); a
    # duplicated class would otherwise gather the same samples twice.
    ls_idx = [
        tf.reshape(tf.where(labels == itr_cls), [-1])
        for itr_cls in dict.fromkeys(list_classes)]
    # Sort so the kept samples stay in their original batch order.
    idx = tf.sort(tf.concat(ls_idx, axis=0))

    llrs_rest = tf.gather(llrs, idx, axis=0)
    lbls_rest = tf.gather(labels, idx, axis=0)

    llrs_rest = None if llrs_rest.shape[0] == 0 else llrs_rest
    lbls_rest = None if lbls_rest.shape[0] == 0 else lbls_rest

    return llrs_rest, lbls_rest


def extract_positive_row(llrs, labels):
    """ Extract the y_i-th row of each LLR matrix.

    Args:
        llrs: A Tensor with shape (batch, duration, num classes, num classes).
        labels: An integer Tensor with shape (batch,).
    Returns:
        llrs_posrow: A Tensor with shape (batch, duration, num classes) and
            the dtype of `llrs`, where
            llrs_posrow[i, t, :] = llrs[i, t, labels[i], :].
    NOTE: the row is selected by multiplying with a one-hot mask and summing,
    so non-finite values in other rows (inf * 0) would propagate as NaN.
    """
    num_classes = llrs.shape[2]

    # Build the mask in llrs' dtype so non-float32 LLRs (e.g. float64) work.
    labels_oh = tf.one_hot(labels, depth=num_classes, axis=1, dtype=llrs.dtype)
        # (batch, num cls)
    labels_oh = tf.reshape(labels_oh, [-1, 1, num_classes, 1])
        # (batch, 1, num cls, 1): broadcasts over duration and columns,
        # so no explicit tile is needed.

    llrs_posrow = tf.reduce_sum(llrs * labels_oh, axis=2)
        # (batch, duration, num cls): = LLR_{:, :, y_i, :}

    return llrs_posrow


def add_max_to_diag(llrs):
    """ Add the global max |LLR| to the diagonal of every LLR matrix.

    Args:
        llrs: A Tensor with shape (batch, duration, num classes, num classes).
    Returns:
        llrs_maxdiag: A Tensor with the shape and dtype of `llrs`, where
            max(|llrs|) -- taken over the whole tensor, i.e. all batch
            elements and time steps -- is added to each diagonal entry.
    """
    num_classes = llrs.shape[2]

    llrs_max = tf.reduce_max(tf.abs(llrs))
        # max |LLRs| (scalar)
    # Identity in llrs' dtype; tensor_diag of a Python float list is always
    # float32 and would clash with e.g. float64 LLRs.
    tmp = tf.eye(num_classes, dtype=llrs.dtype) * llrs_max
    tmp = tf.reshape(tmp, [1, 1, num_classes, num_classes])
        # broadcasts over batch and duration
    llrs_maxdiag = llrs + tmp

    return llrs_maxdiag


def plot_heatmatrix(mx, figsize=(10,7), annot=True):
    """
    Show a matrix as a seaborn heatmap in a new figure.

    Args:
        mx: A square matrix.
        figsize: A tuple of two positive integers (width, height) of the
            new figure.
        annot: A bool. Whether to write each cell's value at its center.
    """
    # A fresh figure keeps previously drawn plots untouched.
    plt.figure(figsize=figsize)
    sns.heatmap(mx, annot=annot, ax=plt.gca())
    plt.show()


# https://www.tensorflow.org/tensorboard/image_summaries
def plot_confusion_matrix(cm, class_names):
    """
    Returns a matplotlib figure containing the plotted confusion matrix.

    Cells are colored by their raw counts and annotated with the
    row-normalized value (fraction of the true class), rounded to 2 decimals.
    Rows summing to zero (classes without samples) are annotated with 0
    instead of NaN.

    Args:
    cm (array-like, shape = [n, n]): a confusion matrix of integer classes
        (rows: true labels, columns: predicted labels)
    class_names (array, shape = [n]): String names of the integer classes
    Returns:
    figure: the new matplotlib Figure (also left as pyplot's current figure).
    """
    cm = np.asarray(cm)  # accept lists / other array-likes too

    figure = plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    # Compute the labels from the row-normalized confusion matrix; empty rows
    # would divide by zero, so they are left at 0.
    row_sums = cm.sum(axis=1, keepdims=True).astype('float')
    normalized = np.divide(
        cm.astype('float'), row_sums,
        out=np.zeros(cm.shape, dtype='float'), where=row_sums != 0)
    labels = np.around(normalized, decimals=2)

    # Use white text if squares are dark; otherwise black.
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        color = "white" if cm[i, j] > threshold else "black"
        plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out after the axis labels exist so they are not clipped.
    plt.tight_layout()
    return figure


def plot_to_image(figure):
    """Converts the matplotlib plot specified by 'figure' to a PNG image and
    returns it. The supplied figure is closed and inaccessible after this call.

    Args:
        figure: A matplotlib Figure.
    Returns:
        image: A Tensor with shape (1, height, width, 4): the decoded RGBA
            PNG with a leading batch dimension.
    """
    buf = io.BytesIO()
    try:
        # Save *this* figure; plt.savefig would save whichever figure is
        # current, which need not be `figure`.
        figure.savefig(buf, format='png')
    finally:
        # Closing the figure prevents it from being displayed directly inside
        # the notebook, and frees it even if saving failed.
        plt.close(figure)
    # Convert PNG bytes to a TF image (getvalue() ignores the stream position).
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension
    image = tf.expand_dims(image, 0)
    return image


def calc_normdiff(grads0, grads1):
    """
    Pairwise L2 distance between two equally long lists of gradients.

    - Args:
    grads0, grads1: Lists of gradient tensors of the same length; the
        k-th entries are compared with each other.
    - Returns:
    normdiff: A list whose k-th entry is || grads0[k] - grads1[k] || (L2 norm).
    - Raises:
    AssertionError: If the two lists differ in length.
    """
    assert len(grads0) == len(grads1),\
        "grads0 = {} and \ngrads1 = {}.".format(grads0, grads1)
    normdiff = []
    for g0, g1 in zip(grads0, grads1):
        normdiff.append(tf.norm(g0 - g1))
    return normdiff


def calc_cossim(grads0, grads1):
    """
    Pairwise cosine similarity between two equally long lists of gradients.

    - Args:
    grads0, grads1: Lists of gradient tensors of the same length; the k-th
        entries are flattened and compared with each other.
    - Returns:
    cossim: A list whose k-th entry is cos(grads0[k], grads1[k]),
        a value in [-1, 1].
    - Raises:
    AssertionError: If the two lists differ in length.
    """
    assert len(grads0) == len(grads1),\
        "grads0 = {} and \ngrads1 = {}.".format(grads0, grads1)
    cossim = []
    for g0, g1 in zip(grads0, grads1):
        flat0 = tf.reshape(g0, [-1])
        flat1 = tf.reshape(g1, [-1])
        # The Keras loss returns the *negative* cosine similarity, so flip it.
        cossim.append(-tf.keras.losses.cosine_similarity(flat0, flat1))
    return cossim


