# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.
"""
Module for computational graph execution.

Classes:
    Task: Abstract base class representing a computational task.
    Executor: Class for scheduling and executing directed acyclic task graphs.
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union

import networkx
import torch
import tqdm
from pydantic import BaseModel
from typing_extensions import Generic, TypeVar

ValueT = TypeVar("ValueT")


class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
    """
    Abstract base class representing a task in a computational graph.

    This class should be extended to define specific tasks. Each task can have arguments (dependencies) and a defined execution strategy.

    Attributes:
        Generic[ValueT] (TypeVar): The type of the value that the task returns upon execution.

    Methods:
        arguments: Abstract method to define task arguments (dependencies).
        execute: Abstract method to execute the task.
        priority: Returns the priority of the task for scheduling purposes.
        group_label: Returns an optional label for task grouping.
    """

    @abstractmethod
    def arguments(self) -> Dict[str, "Task"]:
        """
        Returns a dictionary of arguments required for this task. The keys of the dictionary
        are argument names, and the values are Task instances. These keys correspond to the
        keyword argument names expected by the execute method.

        For example, if this method returns {'input1': taskA, 'input2': taskB}, the execute
        method should expect to be called as execute(input1=valueA, input2=valueB), where
        valueA and valueB are the outputs of taskA and taskB respectively.

        Returns:
            Dict[str, "Task"]: A dictionary mapping argument names to Task instances.
        """
        ...

    @abstractmethod
    def execute(self, **kwargs) -> ValueT:
        """
        Executes the task using the results of its dependencies.

        The keyword arguments (**kwargs) for this method are dynamically determined based on
        the dictionary returned by the 'arguments' method. Each key in the 'arguments' method's
        return dictionary becomes a keyword argument in this method, with its value being
        the result of the corresponding task's execution.

        Returns:
            ValueT: The result of the task execution.
        """
        ...

    def priority(self) -> int:
        """
        Returns the priority of the task for scheduling.

        Higher numbers indicate higher priority. Default is 0.

        Returns:
            int: The priority of the task.
        """
        return 0

    def group_label(self) -> Optional[str]:
        """
        Returns an optional label used for grouping tasks together.

        Returns:
            Optional[str]: The group label of the task, if any.
        """
        return None

    def uses_accelerator(self) -> bool:
        """
        Returns True if the task can take advantage of matrix operation
        acceleration (such as on a GPU).
        """
        return False


class Executor:
    """
    Schedules and executes a set of tasks and their dependencies.

    Handles scheduling, execution, the movement of data between devices, and the lifecycle of intermediate results.

    Attributes:
        math_device (torch.device): Device used for tensor computations.
        storage_device (torch.device): Device used for storing intermediate results.
        targets (List[Task]): List of target tasks to be executed.
        schedule (List[Task]): Calculated execution schedule of tasks.
        dependencies (Dict[Task, Set[Task]]): Dependencies of each task.
    """

    math_device: torch.device
    storage_device: torch.device
    targets: List[Task]
    schedule: List[Task]
    dependencies: Dict[Task, Set[Task]]

    def __init__(
        self,
        tasks: List[Task],
        math_device: torch.device = torch.device("cpu"),
        storage_device: torch.device = torch.device("cpu"),
    ):
        """
        Initializes the Executor with a list of tasks and device configurations.

        Args:
            tasks (List[Task]): The list of tasks to be executed.
            math_device (torch.device, optional): The device for tensor computations. Defaults to CPU.
            storage_device (torch.device, optional): The device for storing results. Defaults to CPU.
        """
        self.math_device = math_device
        self.storage_device = storage_device
        self.schedule = self._make_schedule(tasks)
        self.targets = tasks

    def run(self, quiet: bool = False) -> Iterator[Tuple[Task, Any]]:
        """
        Execute the computed schedule and yield the target values.

        Yields:
            Iterator[Tuple[Task, Any]]: An iterator of task-result pairs.
        """
        # determine last usage of each value, so they can be evicted afterwards
        last_use_index = {}
        for idx, task in reversed(list(enumerate(self.schedule))):
            for t in self.dependencies[task]:
                if t not in last_use_index:
                    last_use_index[t] = idx
            if task not in last_use_index:
                last_use_index[task] = idx

        values: Dict[Task, Any] = {}
        for idx, task in (
            pbar := tqdm.tqdm(
                list(enumerate(self.schedule)),
                disable=quiet,
                desc="Executing graph",
            )
        ):
            use_math_device = task.uses_accelerator()

            arguments = {}
            for name, dep in task.arguments().items():
                value = values[dep]

                # ensure any input tensors are on math device if task asks for it
                if use_math_device:
                    if (
                        isinstance(value, torch.Tensor)
                        and value.device != self.math_device
                    ):
                        value = value.to(self.math_device)
                    elif isinstance(value, dict):
                        for key in value:
                            if (
                                isinstance(value[key], torch.Tensor)
                                and value[key].device != self.math_device
                            ):
                                value[key] = value[key].to(self.math_device)

                arguments[name] = value
                del value

            res = task.execute(**arguments)
            del arguments

            if isinstance(res, torch.Tensor) and res.device != self.storage_device:
                res = res.to(self.storage_device)

            values[task] = res
            del res

            if task in self.targets:
                yield (task, values[task])

            # evict unreferenced values
            expired = []
            for key in values:
                if idx >= last_use_index[key]:
                    expired.append(key)

            for key in expired:
                del values[key]

        del values
        del pbar

    def execute(self) -> None:
        """
        Execute all tasks and discard results.
        """
        for task, value in self.run():
            pass

    DUMMY_TASK_VALUE = "!!DUMMY!!"

    def _make_schedule(self, targets: List[Task]) -> List[Task]:
        self.schedule = []
        self.dependencies = self._build_dependencies(targets)

        edge_tups = []
        for node in self.dependencies:
            for dependency in self.dependencies[node]:
                edge_tups.append((dependency, node))

        for task in targets:
            # add edges from a dummy node to each target to guarantee
            # they will be included in the final schedule
            edge_tups.append((Executor.DUMMY_TASK_VALUE, task))

        def _compare_key(task: Union[Task, str]):
            if task == Executor.DUMMY_TASK_VALUE:
                return ("", 0)
            return (
                task.group_label() or "",
                -task.priority(),
            )

        graph = networkx.DiGraph(edge_tups)
        res = [
            t
            for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)
            if t != Executor.DUMMY_TASK_VALUE
        ]
        return res

    def _build_dependencies(self, targets: List[Task]) -> Dict[Task, Set[Task]]:
        task_dependencies: Dict[Task, Set[Task]] = {}
        to_process = list(targets)
        while to_process:
            child = to_process.pop()
            if child in task_dependencies:
                continue

            task_dependencies[child] = set()
            for _, dep in child.arguments().items():
                task_dependencies[child].add(dep)
                to_process.append(dep)
        return task_dependencies
