import pathlib
import subprocess

import tensorflow as tf
from tensorflow.python.keras.utils.layer_utils import count_params
from tensorflow.python.profiler import model_analyzer
from tensorflow.python.profiler.option_builder import ProfileOptionBuilder


class Location:
    """A directory on a remote machine, accessed through ``ssh`` and ``rsync``.

    ``machine`` is passed unchanged as the host for both commands; ``path`` is
    stored as a ``pathlib.Path``.
    """

    def __init__(self, machine, path):
        self.machine = machine
        self.path = pathlib.Path(path)

        # Cache backing the ``contents`` property; ``None`` means "not listed yet".
        self._contents = None

    @property
    def contents(self):
        """Entries of ``self.path`` on the remote machine.

        The listing is fetched on first access and cached afterwards. A failed
        listing yields ``[]``, which is cached as well.
        """
        if self._contents is None:
            self._contents = self.ls_remote_dir()

        return self._contents

    @staticmethod
    def _decode_output(raw):
        """Decode captured subprocess bytes without raising on unexpected bytes."""
        # Remote file names and error messages are not guaranteed to be ASCII;
        # undecodable bytes become U+FFFD instead of aborting the whole call.
        return raw.decode("utf-8", errors="replace")

    @staticmethod
    def _parse_listing(raw):
        """Split raw ``ls`` output into its non-blank lines (lines are not stripped)."""
        text = Location._decode_output(raw)
        return [line for line in text.split("\n") if line.strip()]

    def ls_remote_dir(self, path=None):
        """List a directory on the remote machine via ``ssh <machine> ls <path>``.

        Args:
            path: Directory to list; defaults to ``self.path``.

        Returns:
            The non-blank lines of the ``ls`` output, or ``[]`` (after printing a
            message that includes the command's stderr) if the command exits
            with a non-zero status.
        """
        path = str(self.path if path is None else path)
        result = subprocess.run(["ssh", self.machine, "ls", path], capture_output=True)
        if result.returncode != 0:
            stderr = self._decode_output(result.stderr).strip()
            print(f"Could not list '{path}' on '{self.machine}': {stderr}")
            return []

        return self._parse_listing(result.stdout)

    def sync_remote_subdir(self, subdir, target_dir):
        """Copy ``self.path / subdir`` from the remote machine into ``target_dir``.

        Runs ``rsync -av <machine>:<path/subdir> <target_dir>`` and prints whether
        the sync succeeded. Failure is reported but not raised.

        Args:
            subdir: Path relative to ``self.path`` on the remote machine.
            target_dir: Local directory that receives the copy.

        Returns:
            The local path where rsync places the copy. It is returned whether
            or not the sync succeeded.
        """
        remote_path = self.path / subdir
        remote = f"{self.machine}:{remote_path}"
        target_dir = pathlib.Path(target_dir)
        result = subprocess.run(
            ["rsync", "-av", remote, str(target_dir)],
            capture_output=True,
        )
        if result.returncode == 0:
            print(f"Synced {remote}")
        else:
            stderr = self._decode_output(result.stderr).strip()
            print(f"Could not sync {remote}: {stderr}")

        # The source has no trailing slash (pathlib drops it), so rsync creates
        # only the final path component inside target_dir. For a nested subdir
        # such as "a/b" the copy lands in target_dir/b, not target_dir/a/b.
        return target_dir / remote_path.name


def count_embedding_params(model):
    """Sum the parameter counts of top-level layers named like an embedding.

    A layer is included when the substring ``"embedding"`` occurs in its
    ``name``. Only the entries of ``model.layers`` are inspected: layers nested
    inside other (custom) layers are not visited, and a layer whose name lacks
    the substring is ignored.
    """
    return sum(
        layer.count_params() for layer in model.layers if "embedding" in layer.name
    )


def count_nontrainable_params(model):
    """Return the parameter count of the model's non-trainable weights.

    The count is delegated to Keras' ``count_params`` helper, applied to
    ``model.non_trainable_weights``.
    """
    non_trainable = model.non_trainable_weights
    return count_params(non_trainable)


def count_flops_per_step(model):
    """Estimate the float operations per sequence step of one forward pass.

    ``model.call`` is traced with a batch size of 1 and the remaining
    dimensions of ``model.input_shape``. The profiler's total float-op count for
    the traced graph is then floor-divided by ``model.input_shape[1]``, which
    this function treats as the sequence length.

    NOTE(review): the input spec relies on ``tf.TensorSpec``'s default dtype,
    and the division needs ``model.input_shape[1]`` to be a concrete integer --
    confirm both hold for the models passed in.
    """
    input_shape = model.input_shape
    single_example = tf.TensorSpec(shape=(1,) + input_shape[1:])
    traced_call = tf.function(model.call, input_signature=[single_example])
    traced_graph = traced_call.get_concrete_function().graph

    profile = model_analyzer.profile(
        traced_graph,
        options=ProfileOptionBuilder.float_operation(),
    )

    # The total is not always an exact multiple of the sequence length (e.g. for
    # some LMU models such as LMUMLP), so floor-divide instead of asserting
    # divisibility.
    return profile.total_float_ops // input_shape[1]


def parameter_summary(model, width=98, print_fn=print, embedding_params=None):
    """Print a breakdown of parameter counts and return the relevant count.

    "Relevant" parameters are those that are neither non-trainable nor
    embedding parameters: ``total - non_trainable - embedding``.

    Args:
        model: Model providing ``count_params()``, ``non_trainable_weights``
            and ``layers``.
        width: Length of the underscore rule printed after the counts.
        print_fn: Callable invoked once per output line.
        embedding_params: Embedding parameter count to use; when ``None`` it is
            computed with ``count_embedding_params(model)``.

    Returns:
        The number of relevant parameters.
    """
    total = model.count_params()
    non_trainable = count_nontrainable_params(model)
    embedding = (
        count_embedding_params(model) if embedding_params is None else embedding_params
    )
    relevant = total - non_trainable - embedding

    report = (
        f"Total params: {total}",
        f"Non-trainable params: {non_trainable}",
        f"Embedding params: {embedding}",
        f"Relevant params: {relevant}",
    )
    for line in report:
        print_fn(line)
    print_fn("_" * width)

    return relevant


def PaddedSparseCategoricalCrossentropy(
    pad_value=0, from_logits=True, label_dtype=tf.int32
):
    """Build a sparse categorical cross-entropy loss that masks padded labels.

    Args:
        pad_value: Label value that marks padding.
        from_logits: Must be truthy; only logits are supported.
        label_dtype: Dtype that both ``pad_value`` and ``y_true`` are cast to
            before they are compared.

    Returns:
        A function ``(y_true, y_pred)`` that computes the loss with the mask
        ``y_true != pad_value`` supplied as the per-element weights.

    Raises:
        NotImplementedError: If ``from_logits`` is falsy.
    """
    if not from_logits:
        raise NotImplementedError("from_logits must be True")

    padding_label = tf.cast(pad_value, label_dtype)

    def padded_sparse_softmax_cross_entropy(y_true, y_pred):
        # The Keras SparseCategoricalCrossentropy takes no weights argument,
        # whereas the v1 loss function does; wrapping the v1 function appears
        # to be the safest and most robust way to apply the mask.
        labels = tf.cast(y_true, label_dtype)
        not_padding = labels != padding_label
        return tf.compat.v1.losses.sparse_softmax_cross_entropy(
            labels=labels, logits=y_pred, weights=not_padding
        )

    return padded_sparse_softmax_cross_entropy
