"""Analog of tf.gradients for use in computing fishers.

See the following files for how tf.gradients is implemented:
    - https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/python/ops/gradients_impl.py#L177
    - https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/python/ops/gradients_util.py#L480

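Note that the links above point to TensorFlow 2.4.0, the version this module was written against. The module relies on private TensorFlow internals, so it may not work with other versions.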

"""
import collections
from typing import List

import tensorflow as tf

from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_state
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.util import object_identity

###############################################################################
# Module-level aliases for private helpers used by the implementation of tf.gradients.

# Type / structure helpers.
_AsList = gradients_util._AsList
_IsFunction = gradients_util._IsFunction
_IsPartitionedCall = gradients_util._IsPartitionedCall

# Subgraph analysis between the differentiated tensors and the sources.
_Inputs = gradients_util._Inputs
_PendingCount = gradients_util._PendingCount
_StopOps = gradients_util._StopOps

# Per-tensor gradient bookkeeping.
_AggregatedGrads = gradients_util._AggregatedGrads
_SetGrad = gradients_util._SetGrad

# Context manager wrapped around the gradient construction for each op.
_maybe_colocate_with = gradients_util._maybe_colocate_with


###############################################################################
# For temporary debugging and learning what the method does.
class dotdict(dict):
    """A ``dict`` whose entries can also be accessed as attributes.

    ``d.key`` returns ``d.get("key")``, so a missing key yields ``None``
    instead of raising ``AttributeError``. ``d.key = v`` stores
    ``d["key"] = v``. ``del d.key`` removes the entry and raises ``KeyError``
    if it is absent.
    """

    def __getattr__(self, key):
        # Python only calls this after normal attribute lookup fails, so real
        # dict methods (``items``, ``get``, ...) are never shadowed by entries.
        return dict.get(self, key)

    def __setattr__(self, key, value):
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        dict.__delitem__(self, key)


# Scratch namespace for stashing intermediate values while debugging.
_DEBUG = dotdict()


###############################################################################

def diagonal_fishers(logits: tf.Tensor, variables: List[tf.Variable], per_example: bool) -> List[tf.Tensor]:
    """Build graph ops for diagonal Fisher values w.r.t. ``variables``.

    The log-probabilities are ``tf.math.log_softmax(logits, axis=-1)``, i.e.
    normalized over the last axis. The backward pass is delegated to
    `_FishersHelper`, which raises ``RuntimeError`` under eager execution.

    Args:
        logits: Unnormalized scores, normalized over their last axis.
            NOTE(review): the shape is not checked here -- confirm with callers.
        variables: Variables to compute Fisher values for.
        per_example: Forwarded to `_FishersHelper`. Presumably this selects
            per-example rather than aggregated Fishers -- TODO confirm.

    Returns:
        The result of ``_FishersHelper.create_ops()``. NOTE(review): that
        method currently has no ``return`` statement, so this returns ``None``
        despite the annotation.
    """
    helper = _FishersHelper(
        log_probs=tf.math.log_softmax(logits, axis=-1),
        variables=variables,
        per_example=per_example,
    )
    return helper.create_ops()


###############################################################################

class _FishersHelper:
    """Graph-mode driver for the backward pass used by `diagonal_fishers`.

    Modeled on TensorFlow's internal gradients implementation (see the links in
    the module docstring). ``log_probs`` plays the role of tf.gradients' ``ys``
    and ``variables`` plays the role of its ``xs``. The gradient of
    ``log_probs`` is seeded with ones, and ops are intended to be visited from
    ``log_probs.op`` back toward the variables. Per-op work is delegated to
    `_FwdOpHelper`.

    NOTE(review): work in progress. `create_ops` currently has no return
    statement, and ``per_example`` / ``unconnected_gradients`` are stored but
    never read by this class.
    """

    def __init__(
        self,
        log_probs: tf.Tensor,
        variables: List[tf.Variable],
        per_example: bool,
        # Optional arguments, mirroring TensorFlow's gradients implementation.
        name: str = 'fishers',
        stop_gradients=None,
        colocate_gradients_with_ops: bool = True,
        unconnected_gradients=UnconnectedGradients.NONE,
        src_graph=None,
    ):
        """Store and validate configuration. Graph ops are built by `create_ops`.

        Args:
            log_probs: Tensor whose gradient is seeded with ones.
                `diagonal_fishers` passes ``log_softmax(logits)``.
            variables: Variables (or tensors) to differentiate with respect to.
            per_example: Stored as ``self.per_example``; not read yet.
            name: Name of the name scope opened by `create_ops`.
            stop_gradients: Tensor or list of tensors (normalized with
                ``_AsList``) whose ops are passed to ``_StopOps``.
            colocate_gradients_with_ops: Forwarded to ``_PendingCount`` and
                ``_maybe_colocate_with``.
            unconnected_gradients: Validated by conversion to
                ``UnconnectedGradients``; not used otherwise yet.
            src_graph: Graph containing the forward ops. Defaults to the
                current default graph.

        Raises:
            RuntimeError: If eager execution is enabled.
            ValueError: If ``unconnected_gradients`` is not a valid
                ``UnconnectedGradients`` value.
        """
        self._assert_in_graph_context()

        self.log_probs = log_probs
        self.variables = variables
        self.per_example = per_example
        self.name = name
        self.colocate_gradients_with_ops = colocate_gradients_with_ops

        # Passed through to the tf.gradients helper `_AggregatedGrads` (see
        # `_get_out_grads`); not otherwise used here.
        self.aggregation_method = tf.AggregationMethod.DEFAULT

        self.stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)

        if src_graph is None:
            src_graph = ops.get_default_graph()
        self.src_graph = src_graph

        try:
            self.unconnected_gradients = UnconnectedGradients(unconnected_gradients)
        except ValueError:
            raise ValueError(
                "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

        # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
        # ancestor graphs. This is necessary for correctly handling captured values.
        self.func_graphs = self._make_func_graphs()

    ############################################################

    def _assert_in_graph_context(self):
        """Raise ``RuntimeError`` if eager execution is enabled."""
        if context.executing_eagerly():
            raise RuntimeError('npeff.diagonal_fishers is supported only in a graph context.')

    ############################################################

    def _get_dummy_grad_log_probs(self):
        """Return a tensor of ones with the shape and dtype of ``log_probs``.

        It is used as the initial (upstream) gradient of ``log_probs``,
        analogous to the default ``grad_ys`` of tf.gradients.
        """
        # TODO: The original version does some type checking as well, which I ignore
        # here for brevity
        y = self.log_probs
        return array_ops.fill(
            array_ops.shape(y),
            constant_op.constant(1, dtype=y.dtype, name="dummy_grad_log_probs"))

    def _make_func_graphs(self):
        """Return ``src_graph`` and its enclosing graphs, innermost first.

        The walk outward continues only while each graph is a function graph
        (per ``_IsFunction``). The result is empty when ``src_graph`` is not a
        function graph.
        """
        func_graphs = []
        curr_graph = self.src_graph
        while _IsFunction(curr_graph):
            func_graphs.append(curr_graph)
            # The two function-graph flavours expose their parent differently.
            if isinstance(curr_graph, FuncGraph):
                curr_graph = curr_graph.outer_graph
            else:
                assert isinstance(curr_graph, framework_function._FuncGraph)
                curr_graph = curr_graph._outer_graph
        return func_graphs

    def _convert_variables_to_indexed_slices(self, variables):
        """Convert ``variables`` to graph values for gradient bookkeeping.

        Resource variables are replaced by their ``handle``. Everything is then
        passed through ``ops.internal_convert_n_to_tensor_or_indexed_slices``
        with ``as_ref=True``, so despite this method's name the results may be
        plain tensors as well as IndexedSlices.
        """
        xs = [
            x.handle if resource_variable_ops.is_resource_variable(x) else x
            for x in variables
        ]
        return ops.internal_convert_n_to_tensor_or_indexed_slices(
            xs, name="variable", as_ref=True)

    def _set_grad(self, y, grad_y):
        """Record ``grad_y`` as a gradient of tensor ``y`` in ``self.grads``."""
        return _SetGrad(self.grads, y, grad_y)

    def _initialize_grads(self):
        """Compute pending counts and seed the gradient of ``log_probs``.

        Requires ``self.variables_as_slices`` and ``self.variables_set``, which
        `create_ops` sets before calling this. This method sets ``to_ops``,
        ``from_ops``, ``reachable_to_ops``, ``pending_count``, ``loop_state``
        and ``grads``.
        """

        # Initialize the pending count for ops in the connected subgraph from
        # log_probs to the variables.
        self.to_ops = [self.log_probs.op]
        self.from_ops = [t.op for t in self.variables_as_slices]
        self.reachable_to_ops, self.pending_count, self.loop_state = _PendingCount(
            self.to_ops, self.from_ops, self.colocate_gradients_with_ops, self.func_graphs, self.variables_set)

        # grads: op => list of gradients received on each output endpoint of the
        # op. The gradients for each endpoint are initially collected as a list.
        # When it is time to call the op's gradient function, for each endpoint
        # the list of received gradients is aggregated into an Add() Operation
        # if there is more than one.
        self.grads = {}

        # Add the initial gradients for the log_probs.
        self._set_grad(self.log_probs, self._get_dummy_grad_log_probs())

    def _initialize_queue(self):
        """Build the initial queue of ops that are ready to be processed.

        Must run after `_initialize_grads`, because it reads ``to_ops``,
        ``pending_count``, ``reachable_to_ops`` and ``loop_state``. It also sets
        ``stop_gradient_ops`` and ``stop_ops``.
        """
        # Initialize queue with to_ops.
        self.queue = collections.deque()
        # Add the ops in 'to_ops' into the queue.
        to_ops_set = set()
        for op in self.to_ops:
            # 'ready' handles the case where one output gradient relies on
            # another output's gradient.
            ready = (self.pending_count[op] == 0)
            if ready and op not in to_ops_set and op in self.reachable_to_ops:
                to_ops_set.add(op)
                self.queue.append(op)

        if self.loop_state:
            loop_exits = self.loop_state.ProcessUnusedLoopExits(self.pending_count, to_ops_set)
            for y in loop_exits:
                if backprop_util.IsTrainable(y):
                    self._set_grad(y, self.loop_state.ZerosLikeForExit(y))
                    self.queue.append(y.op)

        self.stop_gradient_ops = [t.op for t in self.stop_gradients]
        self.stop_ops = _StopOps(self.from_ops, self.stop_gradient_ops, self.pending_count, self.variables_set)

    ############################################################

    def _get_out_grads(self, op):
        """Return the aggregated gradients received on each output of ``op``.

        When a loop state exists, the ``_AggregatedGrads`` call is wrapped in
        the gradient while-loop context. Reads ``self.fishers_uid``, which
        `create_ops` sets.
        """
        if self.loop_state:
            self.loop_state.EnterGradWhileContext(op, before=True)
        out_grads = _AggregatedGrads(self.grads, op, self.fishers_uid, self.loop_state,
                                     self.aggregation_method)
        if self.loop_state:
            self.loop_state.ExitGradWhileContext(op, before=True)

        return out_grads

    ############################################################
    
    # TODO: This is the "main" function. Maybe rename if something else describes it better.
    def create_ops(self):
        """Build the backward-pass ops inside a name scope.

        NOTE(review): this is still a work in progress:
          - There is no return statement, so this returns None even though
            `diagonal_fishers` is annotated to return a list of tensors.
          - The `_FwdOpHelper` built for each op is discarded.
          - Nothing is appended to ``self.queue`` inside the loop, so only the
            ops enqueued by `_initialize_queue` are visited.

        Raises:
            RuntimeError: If eager execution is enabled.
        """
        self._assert_in_graph_context()

        name_scope_values = [self.log_probs, *self.variables, *self.stop_gradients]
        with ops.name_scope(self.name, "fishers", name_scope_values) as grad_scope:
            # Get a uid for this call to fishers that can be used to help
            # cluster ops for compilation.
            self.fishers_uid = ops.get_default_graph().unique_name("uid")

            # Must be set before _initialize_grads/_initialize_queue, which read them.
            self.variables_as_slices = self._convert_variables_to_indexed_slices(self.variables)
            self.variables_set = object_identity.ObjectIdentitySet(self.variables_as_slices)
            
            ## TODO: Update this description when I am done.
            # The approach (inherited from tf.gradients, where ys = log_probs
            # and xs = variables) is as follows: collect the ops in the
            # subgraph between the ys and xs, and visit each op only once the
            # gradients w.r.t. all of its outputs have been collected (tracked
            # via pending counts). Then aggregate these gradients if needed,
            # call the op's gradient function, and add the generated gradients
            # to the gradients for its inputs.

            self._initialize_grads()
            self._initialize_queue()

            # _DEBUG.ops = list(self.queue)

            while self.queue:
                # generate gradient subgraph for op.
                op = self.queue.popleft()
                with _maybe_colocate_with(op, self.fishers_uid, self.colocate_gradients_with_ops):
                    op_helper = _FwdOpHelper(self, op)


###############################################################################


class _FwdOpHelper:
    """Process a single forward op during the backward pass of `_FishersHelper`.

    Adapted from the per-op body of the main loop in TensorFlow's
    ``gradients_util`` (see the links in the comments below). Construction:

    1. aggregates the gradients already received on the op's outputs,
    2. looks up the op's gradient function,
    3. enters the gradient while-loop context (if any) and validates the op,
    4. computes the gradients w.r.t. the op's inputs, or uses a list of
       ``None`` when there is nothing to propagate.

    Attributes:
        helper: The owning `_FishersHelper`.
        op: The forward op being processed.
        out_grads: Aggregated gradients, one entry per output of ``op``.
        grad_fn: Gradient function of ``op`` from ``ops.get_gradient_function``,
            or ``None``.
        func_call: Always ``None`` for now (function-call lookup is a TODO).
        is_partitioned_call: Whether ``op`` is a partitioned call.
        is_func_call: Whether ``op`` calls a graph function.
        has_out_grads: Whether any output received a gradient.
        in_grads: Gradients w.r.t. the op's inputs. NOTE: until
            `_compute_in_grads` is implemented, this is ``None`` whenever a
            gradient needed computing.
    """

    def __init__(self, helper, op):
        self.helper = helper
        self.op = op

        self.out_grads = self.helper._get_out_grads(op)

        self._find_grad_fn()

        if self.helper.loop_state:
            self.helper.loop_state.EnterGradWhileContext(op, before=False)

        self._some_check()

        if self._need_to_compute_in_grads():
            self.in_grads = self._compute_in_grads()
        else:
            # If no grad_fn is defined or none of out_grads is available,
            # just propagate a list of None backwards.
            self.in_grads = [None] * len(_Inputs(op, self.helper.variables_set))

    ############################################################

    def _find_grad_fn(self):
        """Classify ``self.op`` and look up its gradient function.

        Sets ``grad_fn``, ``func_call``, ``is_partitioned_call``,
        ``is_func_call`` and ``has_out_grads``. The gradient function is only
        looked up when some output received a gradient and the op is not a
        stop op.

        Raises:
            LookupError: If ``ops.get_gradient_function`` finds no gradient for
                the op. The linked TensorFlow code has a function-call fallback
                here that is not implemented yet.
        """
        op = self.op

        self.grad_fn = None
        self.func_call = None
        self.is_partitioned_call = _IsPartitionedCall(op)
        self.is_func_call = (
            self.helper.src_graph._is_function(op.type) or self.is_partitioned_call)
        # Tensors are matched with isinstance before any truthiness test,
        # presumably because a graph Tensor cannot be used as a Python bool.
        self.has_out_grads = any(isinstance(g, ops.Tensor) or g for g in self.out_grads)

        if self.has_out_grads and (op not in self.helper.stop_ops):
            try:
                self.grad_fn = ops.get_gradient_function(op)
            except LookupError:
                # TODO: Handle function-call ops as TensorFlow does; see
                # https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/python/ops/gradients_util.py#L610-L638
                raise

    ############################################################

    def _need_to_compute_in_grads(self):
        """Return True if input gradients must be computed for ``self.op``.

        This holds when the op has a gradient function (or is a function call)
        and at least one of its outputs received a gradient.
        """
        return bool((self.grad_fn or self.is_func_call) and self.has_out_grads)

    def _compute_in_grads(self):
        """Compute the gradients w.r.t. the inputs of ``self.op``.

        First fills missing output gradients with zeros where appropriate (see
        `_zero_out_grad_if_needed`).

        NOTE: Work in progress. The gradient computation itself is still a
        TODO, so this currently returns ``None``.
        """
        assert self._need_to_compute_in_grads()

        op = self.op

        # NOTE: If _AggregatedGrads didn't compute a value for the i'th
        # output, it means that the cost does not depend on output[i],
        # therefore dC/doutput[i] is 0.
        for i, _ in enumerate(self.out_grads):
            self._zero_out_grad_if_needed(i)

        with ops.name_scope(op.name + "_fisher"):
            with self.helper.src_graph._original_op(op):
                if self.grad_fn:
                    # TODO: Compute in_grads using self.grad_fn.
                    pass
                else:
                    # TODO: Handle function-call ops without a gradient function.
                    pass

    def _zero_out_grad_if_needed(self, index: int) -> bool:
        """Fill a missing gradient for output ``index`` with zeros, if appropriate.

        Only outputs that are trainable, or outputs of a function-call op with
        no gradient function, get a zero gradient.

        Returns:
            True if ``self.out_grads[index]`` was replaced by zeros, else False.
        """
        op = self.op
        out_grad = self.out_grads[index]
        loop_state = self.helper.loop_state

        if isinstance(out_grad, ops.Tensor) or out_grad:
            # A gradient was received for this output; nothing to fill in.
            return False
        if not ((not self.grad_fn and self.is_func_call)
                or backprop_util.IsTrainable(op.outputs[index])):
            return False

        # Only trainable outputs or outputs for a function call that
        # will use SymbolicGradient get a zero gradient. Gradient
        # functions should ignore the gradient for other outputs.
        # TODO(apassos) gradients of resource handles might be an
        # issue here because of zeros.
        if loop_state:
            self.out_grads[index] = loop_state.ZerosLikeV1WhileLoop(op, index)
            return True
        if default_gradient.supports_default_grad(op.outputs[index]):
            # TODO(b/143286622): The supports_default_grad check is needed
            # because While op emits non-differentiable resource tensors
            # as outputs. Remove this check when that is not the case.
            self.out_grads[index] = control_flow_state.ZerosLike(op, index)
            return True
        return False

    ############################################################

    def _some_check(self):
        """Reject ops that this implementation does not support yet.

        This is deliberately over-restrictive: any op with a control-flow
        context is rejected. That subsumes the TensorFlow check quoted below,
        which only fires when ``op._control_flow_context is not None``.

        Raises:
            NotImplementedError: If ``self.op`` has a control-flow context.
        """
        # An explicit raise, rather than ``assert``, keeps this check active
        # under ``python -O``.
        if self.op._control_flow_context is not None:
            raise NotImplementedError(
                "npeff.diagonal_fishers does not support ops inside control-flow "
                "contexts; got op %r (type %s)." % (self.op.name, self.op.type))
        # From:
        # https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/python/ops/gradients_util.py#L642-L654
        #
        # # NOTE(skyewm): We don't support computing gradients wrt a loop variable
        # # unless it's within the context of a single iteration (i.e. the
        # # gradient is wrt to the loop parameter in the body function, not wrt or
        # # through the initial value). This means if we're in a while loop
        # # context, we should never see a switch node from this context.
        # # pylint: disable=protected-access
        # if (control_flow_util.IsSwitch(op) and
        #     op._control_flow_context is not None and
        #     op._control_flow_context.IsWhileContext() and
        #     op._control_flow_context ==
        #     ops.get_default_graph()._get_control_flow_context()):
        #   _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
        # # pylint: enable=protected-access
