from typing import Dict, Text, Callable, List
from collections import defaultdict

"""
The HookManager class is used to manage the registration, invocation, and deregistration of hook functions.

Main features include:
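- Registering hook functions under dotted names.
- Calling all hook functions registered under a given name.
- Unregistering hook functions from a given name.
- Forking child HookManager instances (`fork`, `fork_iterative`) that own the hooks under a name prefix, for nested hook management.
- Checking that every registered hook was actually called (`finalize`).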
"""

class HookManager(object):
    """
    A registry of named hook functions that can be registered, called, and unregistered.

    Hook names are dotted strings such as ``'encoder.layer.3.attn'``. A manager
    can be forked into child managers that each own every hook under a name
    prefix, with that prefix stripped from the hook names:

    - ``fork(P)`` creates a child owning names that start with ``P.``.
    - ``fork_iterative(B, N)`` creates a child owning names that start with
      ``B.N.`` or with the wildcard prefix ``B.*.``.

    Hooks registered or unregistered on a parent after a fork are forwarded to
    every matching child using these same prefix rules.
    """

    def __init__(self, hook_dict: Dict[Text, List[Callable]] = None):
        """
        Initialize a HookManager instance.

        Args:
            hook_dict: A dictionary mapping hook names to lists of callable
                functions. It is used as-is (not copied). If None, an empty
                ``defaultdict(list)`` is used.
        """
        # Test against None rather than truthiness so that a caller-supplied
        # empty dict is kept instead of being silently replaced.
        self.hook_dict = hook_dict if hook_dict is not None else defaultdict(list)
        # Number of times each hook name has been dispatched by __call__.
        self.called = defaultdict(int)
        # Child managers keyed by fork prefix: ``name`` for fork(),
        # ``name.<iteration>`` for fork_iterative().
        self.forks = dict()
        # Fork prefix -> base name, for forks created by fork_iterative(). These
        # forks also own hooks registered under the ``base.*.`` wildcard.
        self._iterative_bases = dict()

    def _route(self, name: Text) -> list:
        """
        Return the ``(child_manager, child_name)`` pairs that own hook *name*.

        *child_name* is *name* with the matched fork prefix removed. An empty
        list means no fork owns *name*, so it belongs to this manager.
        """
        targets = []
        for header, child in self.forks.items():
            prefixes = [header + '.']
            base = self._iterative_bases.get(header)
            if base is not None:
                # Iterative forks also own hooks registered for every iteration.
                prefixes.append(base + '.*.')
            for prefix in prefixes:
                if name.startswith(prefix):
                    targets.append((child, name[len(prefix):]))
                    break
        return targets

    def _collect(self, prefixes) -> defaultdict:
        """
        Gather this manager's hooks whose names start with any of *prefixes*.

        Returns:
            A ``defaultdict(list)`` keyed by hook name with the matched prefix
            stripped. Prefixes are scanned in order, so when two prefixes yield
            the same stripped name, hooks from the earlier prefix come first.
        """
        collected = defaultdict(list)
        for prefix in prefixes:
            for key, value in self.hook_dict.items():
                if not key.startswith(prefix):
                    continue
                stripped = key[len(prefix):]
                # Values are normally lists, but tolerate a bare callable.
                if isinstance(value, list):
                    collected[stripped].extend(value)
                else:
                    collected[stripped].append(value)
        return collected

    def register(self, name: Text, func: Callable):
        """
        Register a hook function under the specified name.

        If one or more existing forks own *name* (see the prefix rules in the
        class docstring), the hook is forwarded to each of them with the fork
        prefix stripped, and it is NOT stored on this manager. Otherwise it is
        appended to this manager's list for *name*. A wildcard (``base.*.``)
        hook therefore reaches only the iterative forks that exist at
        registration time.

        Args:
            name: The non-empty name to register the hook under.
            func: The callable function to register.
        """
        assert name
        targets = self._route(name)
        for child, child_name in targets:
            child.register(child_name, func)
        if not targets:
            # setdefault works whether hook_dict is a defaultdict or a plain dict.
            self.hook_dict.setdefault(name, []).append(func)

    def unregister(self, name: Text, func: Callable):
        """
        Unregister a hook function from the specified name.

        Uses the same fork routing as ``register``. If no fork owns *name*,
        one occurrence of *func* is removed from this manager's list; unknown
        names or functions are ignored.

        Args:
            name: The name the hook was registered under.
            func: The callable function to unregister.
        """
        assert name
        targets = self._route(name)
        for child, child_name in targets:
            child.unregister(child_name, func)
        if not targets:
            # .get avoids inserting an empty entry into a defaultdict.
            hooks = self.hook_dict.get(name)
            if hooks and func in hooks:
                hooks.remove(func)

    def __call__(self, name: Text, **kwargs):
        """
        Call all hook functions registered under the specified name.

        Every hook receives the same keyword arguments; return values are not
        chained between hooks.

        Args:
            name: The name to call hooks for.
            **kwargs: Keyword arguments to pass to the hook functions.

        Returns:
            The return value of the last hook function called. If no hooks are
            registered under *name*, returns the 'ret' keyword argument, or
            None if 'ret' was not given.
        """
        hooks = self.hook_dict.get(name)
        if not hooks:
            return kwargs.get('ret')
        self.called[name] += 1
        ret = None
        for function in hooks:
            ret = function(**kwargs)
        return ret

    def fork(self, name):
        """
        Create a new HookManager for the hooks under prefix ``name.``.

        Existing hooks are copied into the child with the prefix stripped.
        Hooks registered on this manager later are forwarded via ``register``.

        Args:
            name: The prefix to filter hooks by.

        Returns:
            A new HookManager instance with filtered hooks.

        Raises:
            ValueError: If a fork with the same name already exists.
        """
        if name in self.forks:
            raise ValueError(f'Forking with the same name is not allowed. Already forked with {name}.')
        new_hook = HookManager(self._collect([name + '.']))
        self.forks[name] = new_hook
        return new_hook

    def fork_iterative(self, name, iteration):
        """
        Create a new HookManager for one iteration of an iterated hook group.

        The child receives hooks under ``name.<iteration>.`` followed by hooks
        under the wildcard ``name.*.``, both with their prefix stripped. An
        existing fork for the same iteration is replaced.

        Args:
            name: The base name for the hooks.
            iteration: The specific iteration (converted with ``str``).

        Returns:
            A new HookManager instance with filtered hooks.
        """
        header = name + '.' + str(iteration)
        new_hook = HookManager(self._collect([header + '.', name + '.*.']))
        self.forks[header] = new_hook
        # Remember the base so later wildcard registrations reach this fork.
        self._iterative_bases[header] = name
        return new_hook

    def finalize(self):
        """
        Verify that every hook name still holding hooks was called at least once.

        Only this manager's own hooks are checked; forks are not inspected.
        Names whose hooks were all unregistered are skipped.

        Raises:
            ValueError: If any registered hook was never called.
        """
        for name, hooks in self.hook_dict.items():
            if hooks and self.called[name] == 0:
                raise ValueError(f'Hook {name} was registered but never used!')