from typing import Callable, Dict, Iterable, Tuple, Any, Optional
from collections import OrderedDict
import logging
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import running_stat, from_config
from .primitives import Identity, DEFAULT_PRIMITIVE_COSTS, PRIMITIVES, DEFAULT_PRIMITIVES

from timm import utils


# Module-level logger, named after this module so it inherits the package's logging config.
_logger: logging.Logger = logging.getLogger(name=__name__)


class Statistics(nn.Module):
    """Track one running statistic per expert as named buffers.

    A zero-initialised buffer of shape ``dim`` is registered for every key in
    ``expert_keys``. Each forward call computes ``stat_func(x)`` for the batch and
    folds it into the buffer of the given expert with ``stat_update_function``.
    The input is returned unchanged, so the module can sit transparently in a hook.
    """

    def __init__(self, expert_keys, dim, stat_func, stat_update_function):
        super().__init__()

        self.dim = dim
        self.expert_keys = expert_keys
        self.stat_func = stat_func
        self.stat_update_function = stat_update_function

        # Buffers are named after the expert keys so they appear in the state dict.
        for key in expert_keys:
            self.register_buffer(key, torch.zeros(dim))

    def forward(self, x: torch.Tensor, expert_key: str):
        """Update the buffer of ``expert_key`` with the statistic of ``x`` and return ``x``."""
        current = self.stat_func(x)
        assert current.shape[0] == self.dim
        previous = getattr(self, expert_key)
        updated = self.stat_update_function(previous, current)
        setattr(self, expert_key, updated)
        return x

    def __repr__(self):
        return f"Statistics(experts={self.expert_keys}, dim={self.dim})"


class TaskEmbedding(nn.Module):
    """Learnable per-task key used to score how well an input matches a task.

    A single learnable latent vector of width ``base`` is expanded by a stack of
    ``Linear + ReLU`` stages (doubling the width at every stage) and projected to
    ``final_dim`` by a final linear layer. The L2-normalized result is the task
    embedding.
    """

    def __init__(self, base, final_dim, num_stages=4, init_seed=None):
        """
        :param base: Width of the latent vector; stage ``i`` has width ``base * 2**i``.
        :param final_dim: Dimension of the produced embedding.
        :param num_stages: Number of widths in the MLP (``num_stages - 1`` Linear+ReLU stages).
        :param init_seed: If not None, the parameters are initialized under this seed and the
            global torch / CUDA / numpy / ``random`` RNG states are restored afterwards
            (also when construction raises), defaults to None.
        """
        super().__init__()
        self.base = base
        self.num_stages = num_stages
        dims = [base * (2 ** i) for i in range(num_stages)]

        saved_rng_states = None
        if init_seed is not None:
            # Seed to ensure the same initialization for every task
            _logger.info(f"Setting seed to {init_seed}")
            saved_rng_states = (
                torch.get_rng_state(),
                torch.cuda.get_rng_state_all(),
                np.random.get_state(),
                random.getstate(),
            )

        try:
            if init_seed is not None:
                utils.random_seed(init_seed, 0)

            self.latent = nn.Parameter(torch.randn(1, dims[0]))
            _logger.info(f"Latent: {self.latent}")

            self.mlps = nn.Sequential(*[
                nn.Sequential(nn.Linear(dims[i], dims[i + 1]), nn.ReLU())
                for i in range(num_stages - 1)
            ])
            self.final_linear = nn.Linear(dims[-1], final_dim)
        finally:
            # Always restore the caller's RNG streams so the local seeding never leaks
            # into unrelated code, even if a layer constructor raised.
            if saved_rng_states is not None:
                torch_state, torch_cuda_states, np_state, rand_state = saved_rng_states
                torch.set_rng_state(torch_state)
                torch.cuda.set_rng_state_all(torch_cuda_states)
                np.random.set_state(np_state)
                random.setstate(rand_state)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, x, mode="key-products"):
        """Return the key products of ``x`` with this task's embedding, or the embedding itself.

        Autocast is disabled for this computation (see decorator); ``x`` is cast to the
        embedding's dtype.

        :param x: Input encodings, BxD with ``D == final_dim`` (only used in ``"key-products"`` mode).
        :param mode: ``"key-products"`` returns the Bx1 cosine similarities between the
            L2-normalized rows of ``x`` and the normalized embedding; any other value
            returns the 1xD normalized embedding, defaults to ``"key-products"``.
        """
        # Generate the embedding
        latent = self.final_linear(self.mlps(self.latent))
        # Normalize
        embedding = F.normalize(latent, dim=-1)  # 1xD

        if mode == "key-products":
            # Both operands are unit-norm, so the product is a cosine similarity
            key_products = torch.matmul(F.normalize(x.type(embedding.dtype), dim=1), embedding.T)  # Bx1
            return key_products
        return embedding


class ContinualMoE(nn.Module):
    """Mixture-of-experts wrapper that grows a backbone one task at a time.

    Every memory layer yielded by ``iter_backbone(backbone, mode="memory")`` owns an
    ``nn.ModuleDict`` of experts (``self.experts[l]``). Forward pre/post hooks on those
    layers route the layer's computation through the expert selected for the current
    task (``self.task_idx``) or through an explicitly requested expert id.
    """

    def __init__(self,
                 backbone: nn.Module,
                 config: dict,
                 iter_backbone: Callable[[nn.Module], Iterable[Tuple[nn.Module, nn.Module]]],
                 head_factory: Callable[[nn.Module], nn.ModuleList],
                 op_factory_generator: Callable[[nn.Module, str], Callable],
                 stat_funcs: Dict[str, Callable[[torch.Tensor], torch.Tensor]]=None,
                 stat_update_function: Callable[[torch.Tensor, torch.Tensor, Optional[Any]], torch.Tensor]=running_stat,
                 initial_expert_ids: str="expert_0",
                 encoder_nf=20,
                 init_seed=0,
                 task_embed_exit_layer=13, # last embedding layer
                 **kwargs):
        """ArtiHippo wrapper.

        :param backbone: Backbone model
        :type backbone: nn.Module
        :param config: Per-layer expert configs from which the experts are initialized;
            an empty (or None) config creates one identity expert per memory layer
        :type config: dict
        :param iter_backbone: A function returning the memory and statistics layers of the backbone model.
            It is called as ``iter_backbone(backbone, mode="memory")`` (yielding
            ``(name, layer, dim, primitives)``) and ``iter_backbone(backbone, mode="statistics")``
            (yielding ``(name, layer, dim)``)
        :type iter_backbone: Callable[[nn.Module], Iterable[Tuple[nn.Module, nn.Module]]]
        :param head_factory: A function returning a list of task classification heads. The function should take the backbone model as argument.
        :type head_factory: Callable[[nn.Module], nn.ModuleList]
        :param op_factory_generator: Function which returns the operator factory function for a primitive. The function should take the memory layer and the primitive name as arguments.
        :type op_factory_generator: Callable[[nn.Module, str], Callable]
        :param stat_funcs: Dictionary of statistic name and computing function, defaults to None
        :type stat_funcs: Dict[str, Callable[[torch.Tensor], torch.Tensor]], optional
        :param stat_update_function: Statistic update function across batches, defaults to running_stat
        :type stat_update_function: Callable[[torch.Tensor, torch.Tensor, Optional[Any]], torch.Tensor], optional
        :param initial_expert_ids: ID of the initial identity expert of every layer (empty config only), defaults to "expert_0"
        :type initial_expert_ids: str, optional
        :param encoder_nf: Base width of every TaskEmbedding MLP, defaults to 20
        :param init_seed: Base seed of the task embeddings (task ``t`` uses ``init_seed + t + 1``);
            None disables seeding, defaults to 0
        :param task_embed_exit_layer: Index of the backbone block after which the task encoding
            (token 0 of the block output) is taken, defaults to 13
        :param kwargs: Forwarded to ``from_config`` when building the experts from ``config``
        """
        super().__init__()

        self.backbone = backbone
        self.iter_backbone = iter_backbone
        self.stat_funcs = stat_funcs
        self.stat_update_function = stat_update_function

        self.initial_expert_ids = initial_expert_ids

        # Init experts based on the config
        self.init_experts(config, op_factory_generator, **kwargs)
        if stat_funcs is not None:
            self.init_statistics()

        # One head per task: Should be an nn.ModuleList
        self.heads = head_factory(backbone)
        # Copy the weights for heads[0] from the backbone
        self.heads[0].load_state_dict(self.backbone.head.state_dict())
        self.final_dim = self.heads[0].in_features
        self.init_task_embeddings(encoder_nf, init_seed)

        self.grad_off()

        self.preprocessing_hook_handles = []
        self.postprocessing_hook_handles = []
        self.stat_hook_handles = []
        self._conditional_execution_hooks_registered = False

        self.task_idx = None
        self.forward_state_args = None
        self._expert_ids_to_use = None

        self.task_embed_exit_layer = task_embed_exit_layer

        self.register_conditional_execution_hooks()

    def init_task_embeddings(self, encoder_nf, init_seed):
        """Create one TaskEmbedding per task head.

        Task ``t`` is seeded with ``init_seed + t + 1`` so every task gets a distinct but
        reproducible initialization; ``init_seed=None`` leaves all of them unseeded.
        """
        self.task_embeddings = nn.ModuleList([
            TaskEmbedding(
                encoder_nf, self.final_dim, num_stages=4,
                init_seed=None if init_seed is None else init_seed + tsk + 1,
            )
            for tsk in range(len(self.heads))
        ])

    def grad_off(self):
        """Freeze the experts, the task embeddings and the heads."""
        self.experts.requires_grad_(False)
        self.task_embeddings.requires_grad_(False)
        self.heads.requires_grad_(False)

    def set_task_idx(self, task_idx, train=False):
        """Select the active task and (optionally) unfreeze its trainable parameters.

        Everything is frozen first. If ``train`` is True and ``task_idx > 0``, the task's head
        and every expert used only by this task are unfrozen. The task's embedding is always
        set to ``requires_grad=train``, including for task 0.
        """
        # NOTE: training task 0 (i.e. the backbone model) is not prevented here.
        self.task_idx = task_idx
        self.experts.requires_grad_(False)
        self.task_embeddings.requires_grad_(False)
        self.heads.requires_grad_(False)

        if task_idx > 0:
            self.heads[task_idx].requires_grad_(train)
            for l, expert_ids in enumerate(self.task_to_expert_map):
                expert_id = expert_ids[task_idx]
                expert_op = self.experts[l][expert_id]
                # Make sure that the expert is not shared with other tasks
                if len(expert_op.associated_tasks) == 1:
                    expert_op.requires_grad_(train)

        self.task_embeddings[self.task_idx].requires_grad_(train)

    def register_conditional_execution_hooks(self):
        """Attach the routing pre/post forward hooks to every memory layer of the backbone.

        Only the backbone needs conditional execution. The hooks read
        ``self._expert_ids_to_use`` and ``self.task_idx`` at call time.
        """
        for l, (layer_name, memory_layer, dim, primitives) in enumerate(self.iter_backbone(self.backbone, mode="memory")):

            # The pre-hook attaches an ``execute_op`` flag to the layer input
            preprocessing_hook_handle = memory_layer.register_forward_pre_hook(self.preprocessing_hook(l))
            self.preprocessing_hook_handles.append(preprocessing_hook_handle)

            # The post-hook replaces the layer output with the selected expert's output
            postprocessing_hook_handle = memory_layer.register_forward_hook(self.postprocessing_hook(l))
            self.postprocessing_hook_handles.append(postprocessing_hook_handle)

        self._conditional_execution_hooks_registered = True

    def register_stat_hooks(self, task_idx: int):
        """Register hooks that accumulate statistics for the experts of ``task_idx``.

        Hooks are only registered for experts used exclusively by ``task_idx``. For task 0,
        they go on the backbone statistics layer; for other tasks, on the expert op itself.
        Hooks from a previous call are removed first so their handles are never lost.
        """
        assert self._conditional_execution_hooks_registered, "Conditional execution hooks must be registered before registering the stat hooks."

        # Drop any hooks registered by an earlier call; otherwise they would stay attached
        # forever because their handles are overwritten below.
        self.remove_stat_hooks()
        for l, (layer_name, layer, dim) in enumerate(self.iter_backbone(self.backbone, mode="statistics")):
            layer_stat_hooks = dict()
            # Register the hooks for the experts of the current task
            expert_id = self.task_to_expert_map[l][task_idx]
            expert_op = self.experts[l][expert_id]
            # Only register the hook if the expert is unique to the current task
            # This way, the stats for other experts will never be calculated
            if len(expert_op.associated_tasks) == 1 and layer is not None:
                if task_idx == 0:
                    expert_stat_hook_handle = layer.register_forward_hook(self.stat_hook(l, expert_id))
                else:
                    expert_stat_hook_handle = expert_op.register_forward_hook(self.stat_hook(l, expert_id))
                layer_stat_hooks[expert_id] = expert_stat_hook_handle

            # Save the handles since we need to remove the hooks later
            self.stat_hook_handles.append(layer_stat_hooks)

    def remove_stat_hooks(self):
        """Remove every stat hook registered by :meth:`register_stat_hooks`."""
        for layer_stat_hooks in self.stat_hook_handles:
            for stat_hook_handle in layer_stat_hooks.values():
                stat_hook_handle.remove()
        self.stat_hook_handles = []

    def init_experts_from_empty_config(self):
        """Create one identity expert (owned by task 0) per memory layer."""
        experts = []
        for layer, (layer_name, memory_layer, dim, primitives) in enumerate(self.iter_backbone(self.backbone, mode="memory")):
            layer_experts = dict()
            layer_task_to_expert_map = dict()

            # The same id must be used for the ModuleDict key and the task map; a
            # hard-coded "expert_0" here broke any non-default ``initial_expert_ids``.
            expert_id = self.initial_expert_ids
            op = Identity(size=DEFAULT_PRIMITIVE_COSTS["identity"])
            op.associate_task(0)
            layer_experts[expert_id] = op
            self.backbone_experts.append(expert_id)

            layer_task_to_expert_map.update({task: expert_id for task in op.associated_tasks})

            experts.append(nn.ModuleDict(layer_experts))

            self.task_to_expert_map.append(layer_task_to_expert_map)

        self.experts = nn.ModuleList(experts)

    def init_experts(self,
                     config: dict,
                     op_factory_generator, # TODO: Typing
                     **kwargs):
        """Build ``self.experts`` and ``self.task_to_expert_map`` from ``config``.

        ``config[l]`` maps expert ids of memory layer ``l`` to their configs (with at least
        ``"primitive"`` and ``"associated_tasks"``). An empty or None config falls back to
        :meth:`init_experts_from_empty_config`.
        """
        self.task_to_expert_map = []
        self.backbone_experts = []

        if not config:
            self.init_experts_from_empty_config()
            return

        experts = []
        for layer, (layer_name, memory_layer, dim, primitives) in enumerate(self.iter_backbone(self.backbone, mode="memory")):

            layer_config = config[layer]
            layer_experts = dict()
            layer_task_to_expert_map = dict()
            for expert_id, expert_config in layer_config.items():

                primitive = expert_config["primitive"]

                # "identity" is assigned a special status: it only serves to
                # pass the activations of the memory layer of the backbone model
                # This makes it easier for sampling.
                if primitive == "identity":
                    self.backbone_experts.append(expert_id)
                op = from_config(expert_config, memory_layer, op_factory_generator(memory_layer, primitive), **kwargs)
                layer_experts[expert_id] = op

                associated_tasks = expert_config["associated_tasks"]

                # Associate the tasks with experts
                layer_task_to_expert_map.update({task: expert_id for task in associated_tasks})

            experts.append(nn.ModuleDict(layer_experts))
            self.task_to_expert_map.append(layer_task_to_expert_map)
        self.experts = nn.ModuleList(experts)

    def init_statistics(self):
        """Create, for every statistic in ``self.stat_funcs``, one Statistics module per layer.

        The modules are stored as an ``nn.ModuleList`` attribute named after the statistic.
        """
        for stat, stat_func in self.stat_funcs.items():
            stats = []
            # TODO: stats only when defined
            for l, (layer_name, layer, dim) in enumerate(self.iter_backbone(self.backbone, mode="statistics")):

                _stat = Statistics(list(self.experts[l].keys()), dim, stat_func, self.stat_update_function)
                stats.append(_stat)

            setattr(self, stat, nn.ModuleList(stats))

    def determine_computation(self, l: int, expert_id: str):
        """Return True if expert ``expert_id`` of layer ``l`` needs the backbone memory layer output.

        "identity" needs it, "skip"/"new" do not, and "reuse"/"adapt" inherit the answer
        from their parent expert. Any other primitive returns False.
        """
        primitive = self.experts[l][expert_id].primitive

        execute_op = False
        if primitive in ["skip", "new"]:
            # Don't process using the backbone memory layer
            execute_op = False
        elif primitive == "identity":
            # Backbone model layer being used
            execute_op = True
        elif primitive in ["reuse", "adapt"]:
            # Find if any of the experts depend on the backbone memory layer output
            parent_expert_id = self.experts[l][expert_id].parent_expert_id
            execute_op = self.determine_computation(l, parent_expert_id)
        return execute_op

    def preprocessing_hook(self, l: int):
        """Build the forward pre-hook for memory layer ``l``.

        The hook returns ``(input, {"execute_op": flag})`` so that the layer's
        ``conditional_forward`` wrapper can skip the layer when the selected expert does
        not need its output.
        """

        def preprocess(module, input):

            assert isinstance(input, tuple), "The forward function of the memory layer should be decorated by \"artihippo.layers.conditional_forward\"."

            # Is the expert id to use defined by the forward/nas method?
            expert_id = self._expert_ids_to_use[l]
            expert_id = self.task_to_expert_map[l][self.task_idx] if expert_id is None else expert_id
            execute_op = self.determine_computation(l, expert_id)

            return input, {"execute_op": execute_op}

        return preprocess

    def expert_forward(self, l: int, expert_id: str, expert_input, expert_output: torch.Tensor):
        """Evaluate expert ``expert_id`` of layer ``l``.

        "reuse"/"adapt" experts first evaluate their parent recursively; "adapt" receives
        both the layer input and the parent's output, "reuse" only the parent's output.
        "skip"/"new"/"identity" experts are applied to the memory layer output.

        :raises ValueError: if the expert's primitive is unknown.
        """
        expert_op = self.experts[l][expert_id]
        primitive = expert_op.primitive
        if primitive in ["reuse", "adapt"]:
            parent_expert_id = expert_op.parent_expert_id
            parent_expert_output = self.expert_forward(l, parent_expert_id, expert_input, expert_output)
            if primitive == "adapt":
                expert_output = expert_op(expert_input, parent_expert_output, **self.forward_state_args)
            else:
                expert_output = expert_op(parent_expert_output, **self.forward_state_args)
        elif primitive in ["skip", "new", "identity"]:
            expert_output = expert_op(expert_output, **self.forward_state_args)
        else:
            raise ValueError(f"Primitive \"{primitive}\" is not supported.")
        return expert_output

    def postprocessing_hook(self, l: int):
        """Build the forward hook that replaces memory layer ``l``'s output with the selected expert's output."""

        def postprocess(module, input, output):
            # Is the expert id to use defined by the forward/nas method?
            expert_id = self._expert_ids_to_use[l]
            expert_id = self.task_to_expert_map[l][self.task_idx] if expert_id is None else expert_id
            # The preprocessing hook will shortcut x in case of new and skip
            # i.e. output == input
            # Input will be modified by preprocessing func and conditional_compute wrapper
            module_inputs, forward_modifiers = input
            _input = module_inputs[0] # For now, assume only one input
            expert_output = self.expert_forward(l, expert_id, _input, output)
            return expert_output

        return postprocess

    def stat_hook(self, l: int, expert_id: str):
        """Build a forward hook that updates every tracked statistic of layer ``l`` for ``expert_id``."""

        def calculate_stats(module, input, output):

            with torch.no_grad():
                for stat in self.stat_funcs.keys():
                    stat_module = getattr(self, stat)[l]
                    stat_module(output, expert_id)

        return calculate_stats

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        """Return the state dict without any entry whose key contains ``"backbone"``."""
        _state_dict = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        # Remove the backbone parameters
        keys_wo_backbone = [key for key in _state_dict.keys() if "backbone" not in key]

        return OrderedDict({key: _state_dict[key] for key in keys_wo_backbone})

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = False):
        """Load a state dict; ``strict`` defaults to False because backbone keys are never saved."""
        return super().load_state_dict(state_dict, strict)

    def set_forward_args(self, expert_ids=None, **kwargs):
        """Store the extra expert kwargs and the per-layer expert ids (None = use the task's expert)."""
        self.forward_state_args = kwargs
        self._expert_ids_to_use = expert_ids or [None for _ in range(len(self.experts))]

    def unset_forward_args(self):
        """Reset the state set by :meth:`set_forward_args`."""
        self.forward_state_args = None
        self._expert_ids_to_use = [None for _ in range(len(self.experts))]

    def encoding(self, x, mode="train", do_key_products=True):
        """Compute the task encoding of ``x`` and its products with the task keys.

        The backbone runs in eval mode through the backbone (identity) experts only, at most up
        to ``self.task_embed_exit_layer``. The previous train mode and expert selection are
        restored afterwards, even if the forward pass raises.

        :param mode: ``"train"`` scores ``x`` against the current task's embedding only; any other
            value scores it against every task and predicts the task index (argmax).
        :param do_key_products: Currently unused; kept for interface compatibility.
        :return: ``(encoding, key_products, predicted_task_idxs)``; ``predicted_task_idxs`` is
            None in ``"train"`` mode.
        """
        is_training = self.training

        # Encoding using the backbone model, in eval mode, with the backbone experts
        self.backbone.eval()
        saved_expert_ids = self._expert_ids_to_use
        self._expert_ids_to_use = self.backbone_experts
        try:
            encoding = self.partial_backbone_forward(x)
        finally:
            # Restore
            self._expert_ids_to_use = saved_expert_ids
            if is_training:
                self.backbone.train()

        if mode == "train":
            key_products = self.task_embeddings[self.task_idx](encoding.detach())
            predicted_task_idxs = None
        else:
            key_products = torch.cat([self.task_embeddings[task_idx](encoding.detach()) for task_idx in range(len(self.heads))], dim=-1)
            predicted_task_idxs = torch.argmax(key_products, dim=-1)

        return encoding, key_products, predicted_task_idxs

    def partial_backbone_forward(self, x):
        """Run the backbone up to block ``self.task_embed_exit_layer`` and return token 0.

        If the exit layer is beyond the last block, the full pre-logits features are returned.
        """
        x = self.backbone.patch_embed(x)
        x = self.backbone._pos_embed(x)
        x = self.backbone.norm_pre(x)
        for i, blk in enumerate(self.backbone.blocks):
            x = blk(x)
            if i == self.task_embed_exit_layer:
                x = x[:, 0, :]
                return x
        x = self.backbone.norm(x)
        x = self.backbone.forward_head(x, pre_logits=True)
        return x

    def backbone_forward(self, x):
        """Run the full backbone (with the expert hooks) and return the pre-logits features."""
        encoding = self.backbone.forward_features(x)
        encoding = self.backbone.forward_head(encoding, pre_logits=True)
        return encoding

    def forward(self, x, expert_ids: Optional[list]=None, pre_logits: bool=False, do_task_encoding=True, encoder_only=False, mode="train", **kwargs):
        """Run task encoding and/or the expert-routed backbone.

        :param expert_ids: Optional per-layer expert ids; None entries use the current task's expert.
        :param pre_logits: Return the features instead of the task head's logits.
        :param do_task_encoding: Compute the task encoding and key products (required in ``"eval"`` mode).
        :param encoder_only: Return right after the task encoding.
        :param mode: ``"eval"`` sets ``self.task_idx`` to the predicted task (batch size must be 1).
        :param kwargs: Extra keyword arguments passed to every expert call.
        :return: ``(encoding, key_products, predicted_task_idxs)`` if ``encoder_only`` else
            ``(output, key_products, predicted_task_idxs)``. Values that were not computed are None.
        """
        self.set_forward_args(expert_ids=expert_ids, **kwargs)
        try:
            # Initialized so the encoder_only path works without task encoding
            encoding, key_products, predicted_task_idxs = None, None, None
            if do_task_encoding:
                encoding, key_products, predicted_task_idxs = self.encoding(x, mode=mode)

            if encoder_only:
                return encoding, key_products, predicted_task_idxs

            if mode == "eval":
                # Set the task idx
                assert predicted_task_idxs is not None, "The predicted task idxs must be defined in eval mode."
                assert x.shape[0] == 1, "The batch size must be 1 in eval mode."
                self.task_idx = predicted_task_idxs[0].item()

            # If not in eval mode, self.task_idx will be set anyway
            x = self.backbone_forward(x)

            if not pre_logits:
                x = self.heads[self.task_idx](x)
        finally:
            # Leave a clean routing state on every exit path (early return or exception)
            self.unset_forward_args()

        return x, key_products, predicted_task_idxs

    @property
    def config(self):
        """Per-layer dicts mapping expert id to that expert's config."""
        _config = []
        for l, layer_experts in enumerate(self.experts):
            layer_expert_configs = dict()
            for expert_id, expert_op in layer_experts.items():
                layer_expert_configs[expert_id] = expert_op.config
            _config.append(layer_expert_configs)

        return _config


class Supernet(ContinualMoE):
    """ContinualMoE extended with a per-layer search space of candidate experts for NAS.

    For task ``task_idx`` every memory layer gets "reuse"/"adapt" candidates for each
    existing non-skip expert, plus optional "new" and "skip" candidates. Per-layer samplers
    pick one candidate per layer at each forward pass unless expert ids are given.
    """

    def __init__(self,
                 task_idx: int,
                 backbone: nn.Module,
                 config: dict,
                 iter_backbone: Callable[[nn.Module], Iterable[Tuple[nn.Module, nn.Module]]],
                 head_factory: Callable[[nn.Module], nn.ModuleList],
                 op_factory_generator: Callable[[nn.Module, str], Callable],
                 sampler, # TODO: Typing
                 stat_funcs: Dict[str, Callable[[torch.Tensor], torch.Tensor]]=None,
                 stat_update_function: Callable[[torch.Tensor, torch.Tensor, Optional[Any]], torch.Tensor]=running_stat,
                 encoder_nf=20,
                 init_seed=0,
                 task_embed_exit_layer=13, # last embedding layer
                 initial_expert_ids: str="expert_0",
                 **kwargs):
        """Supernet wrapper used for NAS.

        :param task_idx: Index of the task being trained (must be > 0)
        :type task_idx: int
        :param backbone: Backbone model
        :type backbone: nn.Module
        :param config: Config file from which the experts are initialized
        :type config: dict
        :param iter_backbone: A function returning the memory and statistics layers of the backbone model
        :type iter_backbone: Callable[[nn.Module], Iterable[Tuple[nn.Module, nn.Module]]]
        :param head_factory: A function returning a list of task classification heads. The function should take the backbone model as argument.
        :type head_factory: Callable[[nn.Module], nn.ModuleList]
        :param op_factory_generator: Function which returns the operator factory function for a primitive. The function should take the memory layer and the primitive name as arguments.
        :type op_factory_generator: Callable[[nn.Module, str], Callable]
        :param sampler: Sampler class used to sample the operators for each layer; it is called as
            ``sampler(primitives, layer_search_space)``. None disables sampling.
        :type sampler: Callable
        :param stat_funcs: Dictionary of statistic name and computing function, defaults to None
        :type stat_funcs: Dict[str, Callable[[torch.Tensor], torch.Tensor]], optional
        :param stat_update_function: Statistic update function across batches, defaults to running_stat
        :type stat_update_function: Callable[[torch.Tensor, torch.Tensor, Optional[Any]], torch.Tensor], optional
        :param encoder_nf: Base width of the task embeddings, defaults to 20
        :param init_seed: Base seed of the task embeddings, defaults to 0
        :param task_embed_exit_layer: Backbone block used for the task encoding, defaults to 13
        :param initial_expert_ids: IDs for the initial identity experts per layer, defaults to "expert_0"
        :type initial_expert_ids: str, optional
        :param kwargs: Forwarded to the parent constructor. For the search-space ops, every
            ``nas_<name>`` entry overrides ``<name>``.
        """

        # NAS supernet is only used after training the initial backbone
        assert task_idx > 0

        # Initialize the previous learned model with the default arguments
        super().__init__(backbone=backbone, config=config,
                        iter_backbone=iter_backbone, head_factory=head_factory,
                        op_factory_generator=op_factory_generator,
                        stat_funcs=stat_funcs,
                        stat_update_function=stat_update_function,
                        encoder_nf=encoder_nf,
                        init_seed=init_seed,
                        task_embed_exit_layer=task_embed_exit_layer,
                        initial_expert_ids=initial_expert_ids, **kwargs)

        # Override the default arguments with NAS arguments of the same name
        # This is for the modules that show different behavior during NAS e.g. the hybrid adapter
        overridden_kwargs = self._resolve_nas_kwargs(kwargs)

        self.task_idx = task_idx

        self.init_search_space_ops(op_factory_generator, **overridden_kwargs)
        self.unfreeze()

        self.learned_ops = None

        # Defined unconditionally so set_forward_args can report a missing sampler clearly
        self.samplers = None
        if sampler is not None:
            # Initialize the sampler
            self.init_samplers(sampler)

        self.seed = None

    @staticmethod
    def _resolve_nas_kwargs(kwargs: dict) -> dict:
        """Return ``kwargs`` where every ``nas_<name>`` value replaces ``<name>``.

        The NAS value wins regardless of the order in which the two keywords were passed.
        """
        nas_prefix = "nas_"
        resolved = {key: value for key, value in kwargs.items() if not key.startswith(nas_prefix)}
        for key, value in kwargs.items():
            if key.startswith(nas_prefix):
                resolved[key[len(nas_prefix):]] = value
        return resolved

    def init_search_space_ops(self,
                              op_factory_generator, # TODO: Typing
                              **kwargs):
        """Add the NAS candidate ops of every memory layer to ``self.experts``.

        Fills ``self.expert_list`` (previously learned expert ids), ``self.nas_expert_list``
        (those that are not "skip") and ``self.nas_search_space`` (candidate ids per layer).
        """
        self.expert_list = []
        self.nas_expert_list = []
        self.nas_search_space = []

        for layer, (layer_name, memory_layer, dim, primitives) in enumerate(self.iter_backbone(self.backbone, mode="memory")):

            primitives_to_use = primitives or DEFAULT_PRIMITIVES

            # Store the previous experts first
            self.expert_list.append(list(self.experts[layer].keys()))

            # Hide the skip experts from NAS since we don't want to reuse them
            _experts = [e for e in self.expert_list[layer] if self.experts[layer][e].primitive != "skip"]

            # Save for future reference
            self.nas_expert_list.append(_experts)
            layer_nas_experts = dict()
            n_experts = len(_experts)
            # Expected number of candidates, checked against what is actually built below
            num_ops = 0
            if "reuse" in primitives_to_use:
                num_ops += n_experts
            if "adapt" in primitives_to_use:
                num_ops += n_experts
            if "new" in primitives_to_use:
                num_ops += 1
            if "skip" in primitives_to_use:
                num_ops += 1
            layer_nas_search_space = []

            if "reuse" in primitives_to_use:
                for expert_id in _experts:
                    # Create the Reuse expert
                    _expert_id = "reuse_"+expert_id
                    layer_nas_experts[_expert_id] = PRIMITIVES["reuse"](expert_id, size=DEFAULT_PRIMITIVE_COSTS["reuse"], **kwargs)
                    layer_nas_search_space.append(_expert_id)
            if "adapt" in primitives_to_use:
                for expert_id in _experts:
                    # Create the Adapt expert
                    _expert_id = "adapt_"+expert_id
                    layer_nas_experts[_expert_id] = PRIMITIVES["adapt"](expert_id, memory_layer, op_factory_generator(memory_layer, "adapt"), size=DEFAULT_PRIMITIVE_COSTS["adapt"], **kwargs)
                    layer_nas_search_space.append(_expert_id)

            # Create the new expert
            if "new" in primitives_to_use:
                layer_nas_experts["new"] = PRIMITIVES["new"](None, memory_layer, op_factory_generator(memory_layer, "new"), size=DEFAULT_PRIMITIVE_COSTS["new"], **kwargs)
                layer_nas_search_space.append("new")

            # Create the skip expert
            if "skip" in primitives_to_use:
                layer_nas_experts["skip"] = PRIMITIVES["skip"](None, memory_layer, op_factory_generator(memory_layer, "skip"), size=DEFAULT_PRIMITIVE_COSTS["skip"], **kwargs)
                layer_nas_search_space.append("skip")

            assert len(layer_nas_search_space) == num_ops, f"Expected {num_ops}, got {len(layer_nas_search_space)}"

            self.nas_search_space.append(layer_nas_search_space)

            self.experts[layer].update(nn.ModuleDict(layer_nas_experts))

    def init_samplers(self, sampler_class):
        """Create one sampler per memory layer over that layer's NAS search space."""
        self.samplers = []
        for l, (layer_name, memory_layer, dim, primitives) in enumerate(self.iter_backbone(self.backbone, mode="memory")):
            layer_sampler = sampler_class(primitives, self.nas_search_space[l])
            self.samplers.append(layer_sampler)

    def _sample_nas_ops(self, seed=0, constant_sample_per_worker=True, worker_id=0, **kwargs):
        """Sample one candidate id per layer.

        With a non-None ``seed``, layer ``l`` is sampled with ``self.seed + l`` (after adding
        ``worker_id`` when ``constant_sample_per_worker`` is False), and ``self.seed`` is advanced
        so that the next call must use a different seed.
        """
        # Setting the seed here ensures that different expert_ids are sampled at each forward pass
        if self.seed is not None:
            assert seed != self.seed, "The seed must be different from the previous seed"
        self.seed = seed

        if not constant_sample_per_worker:
            if seed is not None:
                # NOTE(review): fails with TypeError if worker_id is None here
                self.seed += worker_id

        # Ensure that a different seed is used for each layer (if distributed)
        layer_sampling_seeds = [self.seed + _l if seed is not None else seed for _l in range(len(self.samplers))]
        _sampled_ops = [sampler.sample_op(seed=layer_sampling_seeds[_l], **kwargs) for _l, sampler in enumerate(self.samplers)]
        if seed is not None:
            self.seed += len(self.samplers) - 1 # The number of times the seed is incremented

        return _sampled_ops

    def unfreeze(self):
        """Make the current task's head trainable."""
        self.heads[self.task_idx].requires_grad_(True)

    def arch_size(self, nas_expert_ids):
        """Sum of the ``size`` attributes of the chosen expert of every layer."""
        size = 0
        for layer, expert_id in enumerate(nas_expert_ids):
            op_size = self.experts[layer][expert_id].size
            size += op_size
        return size

    def expert_params(self):
        """Return the parameters of every NAS candidate op followed by the current task's head."""
        assert self.task_idx is not None
        parameters = []
        for l, layer_nas_ops in enumerate(self.nas_search_space):
            for nas_op in layer_nas_ops:
                parameters.extend(self.experts[l][nas_op].parameters())
        parameters.extend(self.heads[self.task_idx].parameters())
        return parameters

    def choose_ops(self, expert_ids: list[str]):
        """Commit the chosen candidate of every layer and drop all other candidates.

        A chosen "reuse" op collapses to its parent expert. Any other newly chosen op is renamed
        to ``expert_<n>`` and added to ``self.expert_list``. The current task is associated
        with the resulting expert.
        """
        self.learned_ops = expert_ids
        for layer, _expert_id in enumerate(expert_ids):

            chosen_op = self.experts[layer][_expert_id]
            expert_id = _expert_id if chosen_op.primitive != "reuse" else chosen_op.parent_expert_id
            layer_expert_ids = set(self.expert_list[layer])

            # Remove all the excess experts
            layer_expert_ids.add(expert_id)
            excess_expert_ids = list(set(self.experts[layer].keys()).difference(layer_expert_ids))
            for excess_expert_id in excess_expert_ids:
                self.experts[layer].pop(excess_expert_id)

            _expert_id = expert_id
            if expert_id not in self.expert_list[layer]:
                # adapt, skip, new
                # Generate a new expert id
                _expert_id = f"expert_{len(self.expert_list[layer])}"
                op = self.experts[layer].pop(expert_id)
                self.experts[layer][_expert_id] = op
                self.expert_list[layer].append(_expert_id)
            self.experts[layer][_expert_id].associate_task(self.task_idx)

    def set_forward_args(
            self,
            task_idx=None,
            expert_ids=None,
            seed: int=0,
            constant_sample_per_worker: bool=True,
            verbose: bool=False,
            worker_id: int=None,
            **kwargs):
        """Select the per-layer experts for the next forward pass.

        Priority: explicit ``expert_ids``, then the learned experts of ``task_idx``, otherwise a
        fresh sample from the layer samplers. ``kwargs`` are stored for the expert calls and
        also forwarded to the samplers.
        """
        self.forward_state_args = kwargs
        if expert_ids is not None:
            # expert_ids will be defined when calculating the task_stats or
            # running the evolutionary search
            assert task_idx is None
        elif task_idx is not None:
            assert expert_ids is None
            expert_ids = [self.task_to_expert_map[l][task_idx] for l in range(len(self.task_to_expert_map))]
        else:
            assert self.samplers is not None, "The sampler must be defined when expert_ids are not defined."
            expert_ids = self._sample_nas_ops(seed=seed, constant_sample_per_worker=constant_sample_per_worker, worker_id=worker_id, **kwargs)

        self._expert_ids_to_use = expert_ids
        if verbose:
            _logger.info(f"[Worker ID: {worker_id}] NAS expert ids: {expert_ids}")

    def unset_forward_args(self):
        """Reset the state set by :meth:`set_forward_args`."""
        self.forward_state_args = None
        self._expert_ids_to_use = [None for _ in range(len(self.experts))]

    def forward(
            self, x, task_idx: int=None,
            expert_ids: list[str]=None, pre_logits: bool=False,
            seed: int=0, constant_sample_per_worker: bool=True,
            verbose: bool=False, worker_id: int=None,
            **kwargs):
        """Run the backbone with the selected/sampled experts and the current task's head.

        See :meth:`set_forward_args` for how experts are selected. Returns the features if
        ``pre_logits`` else the logits of ``self.heads[self.task_idx]``.
        """
        self.set_forward_args(
            task_idx=task_idx, expert_ids=expert_ids, seed=seed,
            constant_sample_per_worker=constant_sample_per_worker,
            verbose=verbose, worker_id=worker_id, **kwargs
        )
        try:
            x = self.backbone_forward(x)
        finally:
            # Reset even if the forward pass raised
            self.unset_forward_args()
        return x if pre_logits else self.heads[self.task_idx](x)

    @property
    def config(self):
        """Per-layer expert configs; only valid after :meth:`choose_ops`."""
        assert self.learned_ops is not None, "Final operators must be chosen for config to be valid."
        return super().config


class SimilaritySupernet(Supernet):
    """Supernet whose NAS sampling is guided by per-expert statistic similarities.

    For every statistics layer, each NAS-visible expert is scored by comparing
    two statistics with cosine similarity:

    * the stored statistic ``self.<similarity_stat>``;
    * the same statistic gathered on the current task
      (``self.task_<similarity_stat>``).

    The resulting per-layer similarity vectors are optionally rescaled to
    ``[-1, 1]``. They are then handed to one sampler per memory layer.
    """

    def __init__(self, 
                 task_idx: int,
                 backbone: nn.Module, 
                 config: dict, 
                 iter_backbone: Callable[[nn.Module], Iterable[Tuple[nn.Module, nn.Module]]],
                 head_factory: Callable[[nn.Module], nn.ModuleList],
                 op_factory_generator, # TODO: Typing
                 sampler,
                 stat_funcs: Dict[str, Callable[[torch.Tensor], torch.Tensor]]=None, 
                 stat_update_function: Callable[[torch.Tensor, torch.Tensor, Optional[Any]], torch.Tensor]=running_stat, 
                 similarity_stat=None,
                 encoder_nf=20,
                 init_seed=0,
                 task_embed_exit_layer=13, # last embedding layer
                 initial_expert_ids: str="expert_0", 
                 normalize_similarities: bool=True,
                 **kwargs):
        """Supernet used with similarity-oriented NAS

        :param task_idx: Index of the task being trained
        :type task_idx: int
        :param backbone: Backbone model
        :type backbone: nn.Module
        :param config: Config file from which the experts are initialized
        :type config: dict
        :param iter_backbone: A function returning the memory and statistics layers of the backbone model
        :type iter_backbone: Callable[[nn.Module], Iterable[Tuple[nn.Module, nn.Module]]]
        :param head_factory: A function returning a list of task classification heads. The function should take the backbone model as argument.
        :type head_factory: Callable[[nn.Module], nn.ModuleList]
        :param op_factory_generator: Function which returns the operator factory function for a primitive. The function should take the memory layer and the primitive name as arguments.
        :type op_factory_generator: Callable[[nn.Module, str], Callable]
        :param sampler: Sampler class used to sample the operators for each layer. It is instantiated per memory layer as ``sampler(primitives, similarities, search_space)`` once the similarities are computed.
        :type sampler: Callable
        :param stat_funcs: Dictionary of statistic name and computing function, defaults to None
        :type stat_funcs: Dict[str, Callable[[torch.Tensor], torch.Tensor]], optional
        :param stat_update_function: Statistic update function across batches, defaults to running_stat
        :type stat_update_function: Callable[[torch.Tensor, torch.Tensor, Optional[Any]], torch.Tensor], optional
        :param similarity_stat: Name of the statistic compared to compute the similarities. ``task_<similarity_stat>`` statistics are only created for keys of ``stat_funcs``, so it is presumably one of those keys. Required.
        :type similarity_stat: str
        :param encoder_nf: Forwarded to the base ``Supernet``, defaults to 20
        :type encoder_nf: int, optional
        :param init_seed: Forwarded to the base ``Supernet``, defaults to 0
        :type init_seed: int, optional
        :param task_embed_exit_layer: Forwarded to the base ``Supernet``, defaults to 13 (last embedding layer)
        :type task_embed_exit_layer: int, optional
        :param initial_expert_ids: IDs for the initial identity experts per layer, defaults to "expert_0"
        :type initial_expert_ids: str, optional
        :param normalize_similarities: Normalize similarities across experts, defaults to True
        :type normalize_similarities: bool, optional
        :raises AssertionError: If ``similarity_stat`` is None.
        """
        # Validate cheaply before building the (expensive) base supernet.
        assert similarity_stat is not None, "Similarity stat must be defined."

        # The base class is built without a sampler: samplers can only be
        # created once the similarities are known (see ``calculate_similarities``).
        super().__init__(task_idx=task_idx,
                         backbone=backbone,
                         config=config,
                         iter_backbone=iter_backbone,
                         head_factory=head_factory,
                         op_factory_generator=op_factory_generator,
                         sampler=None,
                         stat_funcs=stat_funcs,
                         stat_update_function=stat_update_function,
                         encoder_nf=encoder_nf,
                         init_seed=init_seed,
                         task_embed_exit_layer=task_embed_exit_layer,
                         initial_expert_ids=initial_expert_ids,
                         **kwargs)

        self.normalize_similarities = normalize_similarities
        self.sampler = sampler
        self.similarities = None
        self.similarity_stat = similarity_stat
        # Attribute name of the per-task statistics created by init_task_statistics.
        self.task_similarity_stat = f"task_{self.similarity_stat}"

        if stat_funcs is not None:
            self.init_task_statistics()

        self.seed = None

    def init_task_statistics(self):
        """Create a ``task_<stat>`` ModuleList for every statistic in ``self.stat_funcs``.

        Each list holds one ``Statistics`` module per statistics layer, with one
        buffer per expert of that layer. The entry is ``None`` where
        ``iter_backbone`` yields no layer.
        """
        for stat, stat_func in self.stat_funcs.items():
            stats = []
            for l, (layer_name, layer, dim) in enumerate(self.iter_backbone(self.backbone, mode="statistics")):
                layer_stats = None
                if layer is not None:
                    layer_stats = Statistics(list(self.experts[l].keys()), dim, stat_func, self.stat_update_function)
                stats.append(layer_stats)

            setattr(self, f"task_{stat}", nn.ModuleList(stats))
    
    def calculate_similarities(self):
        """Compute per-layer expert similarities and (re)build the samplers.

        For each statistics layer and each expert in ``self.nas_expert_list[l]``,
        the cosine similarity between the stored statistic and the task
        statistic is computed. The global min/max across all layers are kept
        for normalization. Layers without statistics get ``None``.
        """
        with torch.no_grad():
            self._similarities = []
            # Cosine similarity lies in [-1, 1], so these sentinels are always replaced.
            min_similarity = 2.
            max_similarity = -2.
            for l, (layer_name, layer, dim) in enumerate(self.iter_backbone(self.backbone, mode="statistics")):
                layer_similarities = None
                if layer is not None:
                    # Hide skip experts from NAS operation
                    layer_expert_list = self.nas_expert_list[l]
                    stats = getattr(self, self.similarity_stat)[l]
                    task_stats = getattr(self, self.task_similarity_stat)[l]
                    dtype = getattr(stats, layer_expert_list[0]).dtype
                    layer_similarities = torch.zeros(len(layer_expert_list), dtype=dtype)
                    for i, expert_id in enumerate(layer_expert_list):
                        stored_stat = getattr(stats, expert_id)
                        task_stat = getattr(task_stats, expert_id)
                        similarity = F.cosine_similarity(stored_stat, task_stat, dim=0)
                        layer_similarities[i] = similarity
                        if similarity < min_similarity:
                            min_similarity = similarity.item()
                        if similarity > max_similarity:
                            max_similarity = similarity.item()
                self._similarities.append(layer_similarities)

            self.min_similarity = min_similarity
            self.max_similarity = max_similarity

            if self.normalize_similarities:
                self.calculate_normalized_similarities()
                self.similarities = self._normalized_similarities
            else:
                self.similarities = self._similarities

        # Logging for sanity check
        for l, (layer_name, layer, dim) in enumerate(self.iter_backbone(self.backbone, mode="statistics")):
            if layer is not None:
                layer_expert_list = self.nas_expert_list[l]
                layer_similarities = self.similarities[l]
                _logger.info(f"{l} - Experts: {layer_expert_list}")
                _logger.info(f"{l} - Similarities: {layer_similarities}")

        self.init_samplers(self.sampler)

    def register_stat_hooks(self, task_idx: int):
        """Register forward hooks that accumulate task statistics for ``task_idx``.

        In each layer, a hook is registered only if the expert mapped to
        ``task_idx`` was originally learnt for that task
        (``associated_tasks[0] == task_idx``). For task 0 the hook is put on
        the backbone layer itself; otherwise it goes on the expert op. The
        handles are stored in ``self.stat_hook_handles`` (one dict per layer)
        so that they can be removed later.
        """
        assert self._conditional_execution_hooks_registered, "Conditional execution hooks must be registered before registering the stat hooks."

        self.stat_hook_handles = []
        for l, (layer_name, layer, dim) in enumerate(self.iter_backbone(self.backbone, mode="statistics")):
            layer_stat_hooks = dict()
            # Register the hooks for the experts of the current task
            expert_id = self.task_to_expert_map[l][task_idx]
            expert_op = self.experts[l][expert_id]
            # Only register the hook if the expert was originally learnt for task_idx
            # This way, the stats for other experts will never be calculated
            if expert_op.associated_tasks[0] == task_idx and layer is not None:
                if task_idx == 0:
                    expert_stat_hook_handle = layer.register_forward_hook(self.stat_hook(l, expert_id))
                else:
                    expert_stat_hook_handle = expert_op.register_forward_hook(self.stat_hook(l, expert_id))
                layer_stat_hooks[expert_id] = expert_stat_hook_handle

            # Save the handles since we need to remove the hooks later
            self.stat_hook_handles.append(layer_stat_hooks)

    def stat_hook(self, l: int, expert_id: str):
        """Return a forward hook that feeds the module output to every ``task_<stat>[l]`` module."""

        def calculate_stats(module, input, output):
            with torch.no_grad():
                for stat in self.stat_funcs:
                    stat_module = getattr(self, f"task_{stat}")[l]
                    stat_module(output, expert_id)

        return calculate_stats

    def calculate_normalized_similarities(self):
        """Min-max rescale ``self._similarities`` to ``[-1, 1]`` using the global min/max.

        A small epsilon avoids division by zero when all similarities are equal.
        Layers without similarities stay ``None``.
        """
        with torch.no_grad():
            self._normalized_similarities = []
            for l, _similarities in enumerate(self._similarities):
                layer_similarities = None
                if _similarities is not None:
                    layer_similarities = (_similarities - self.min_similarity) / (self.max_similarity - self.min_similarity + 1e-8)
                    layer_similarities = 2. * layer_similarities - 1.
                self._normalized_similarities.append(layer_similarities)

    def init_samplers(self, sampler_class):
        """Instantiate one sampler per memory layer from the current similarities.

        :param sampler_class: Called as ``sampler_class(primitives, similarities, search_space)``.
        """
        self.samplers = []
        for l, (layer_name, memory_layer, dim, primitives) in enumerate(self.iter_backbone(self.backbone, mode="memory")):
            layer_sampler = sampler_class(primitives, self.similarities[l], self.nas_search_space[l])
            _logger.info(f"{l} - Experts: {layer_sampler.experts}")
            _logger.info(f"{l} - Search Space: {layer_sampler.search_space}")
            _logger.info(f"{l} - Similarities: {layer_sampler.similarities}")
            _logger.info(f"{l} - Search Space Prob: {layer_sampler.search_space_prob}")
            self.samplers.append(layer_sampler)


class DARTSSupernet(Supernet):
    """DARTS-style continuous relaxation of the supernet.

    Every searchable layer ``l`` has a learnable architecture vector
    ``self.alphas[l]`` with one weight per candidate in
    ``self.nas_search_space[l]``. The layer's post-processing hook (see
    ``postprocessing_hook``) replaces the layer output with the
    alpha-weighted sum of all candidate experts' outputs.
    """

    def __init__(self, 
                 task_idx: int,
                 backbone: nn.Module, 
                 config: dict, 
                 iter_backbone: Callable[[nn.Module], Iterable[Tuple[nn.Module, nn.Module]]],
                 head_factory: Callable[[nn.Module], nn.ModuleList],
                 op_factory_generator: Callable[[nn.Module, str], Callable],
                 initial_expert_ids: str="expert_0", 
                 **kwargs):
        super().__init__(
            task_idx, 
            backbone, 
            config, 
            iter_backbone, 
            head_factory, 
            op_factory_generator, 
            None,
            initial_expert_ids=initial_expert_ids, 
            **kwargs
        )

        # Remove the preprocessing hooks since we need to compute the expert
        # all the time
        self.remove_preprocessing_hooks()

        self.init_arch_params()

        # "weight" or "alpha"; set by forward() and read by the postprocessing hooks.
        self.update_type = None

        # Maintain a cache of the expert outputs to avoid duplicate computation
        self.cache = [dict() for _ in range(len(self.nas_search_space))]

    def remove_preprocessing_hooks(self):
        """Remove every hook handle stored in ``self.preprocessing_hook_handles``."""
        for layer_hook in self.preprocessing_hook_handles:
            layer_hook.remove()

    def init_arch_params(self):
        """Create one small random alpha vector per searchable layer."""
        self.alphas = nn.ParameterList()

        for layer_search_space in self.nas_search_space:
            alpha = nn.Parameter(1e-3 * torch.randn(len(layer_search_space),))
            self.alphas.append(alpha)

    def arch_parameters(self):
        """Return the architecture parameters (``self.alphas``)."""
        return self.alphas
    
    def _save_arch_parameters(self):
        # Detach so the snapshot holds plain values, not autograd history.
        self._saved_arch_parameters = [alpha.detach().clone() for alpha in self.alphas]

    def softmax_arch_parameters(self):
        """Replace each alpha in place by its softmax, saving the raw values first.

        Undo with ``restore_arch_parameters``. Calling this twice without
        restoring overwrites the saved raw values.
        """
        self._save_arch_parameters()
        with torch.no_grad():
            for alpha in self.alphas:
                alpha.data.copy_(F.softmax(alpha, dim=-1))

    def restore_arch_parameters(self):
        """Restore the alphas saved by ``softmax_arch_parameters`` and drop the snapshot."""
        for alpha, saved_alpha in zip(self.alphas, self._saved_arch_parameters):
            alpha.data.copy_(saved_alpha)
        del self._saved_arch_parameters
    
    def expert_forward(self, l: int, expert_id: str, expert_input, expert_output: torch.Tensor):
        """Compute (or fetch from ``self.cache[l]``) the output of one expert in layer ``l``.

        What the expert receives depends on its primitive:

        * ``"reuse"``: the parent expert's output.
        * ``"adapt"``: the layer input and the parent expert's output.
        * ``"skip"`` / ``"new"``: the layer input.
        * ``"identity"``: the original layer output.

        Parent experts are computed recursively. Results are memoized, so a
        parent shared by several candidates is computed once per pass.

        :raises ValueError: For an unknown primitive.
        """
        # Check if this expert has already been calculated
        _expert_output = self.cache[l].get(expert_id, None)
        if _expert_output is not None:
            return _expert_output

        expert_op = self.experts[l][expert_id]
        primitive = expert_op.primitive
        if primitive in ["reuse", "adapt"]:
            parent_expert_id = expert_op.parent_expert_id
            parent_expert_output = self.expert_forward(l, parent_expert_id, expert_input, expert_output)
            if primitive == "adapt":
                expert_output = expert_op(expert_input, parent_expert_output, **self.forward_state_args)
            else:
                expert_output = expert_op(parent_expert_output, **self.forward_state_args)
        elif primitive in ["skip", "new"]:
            expert_output = expert_op(expert_input, **self.forward_state_args)
        elif primitive == "identity":
            expert_output = expert_op(expert_output, **self.forward_state_args)
        else:
            raise ValueError(f"Primitive \"{primitive}\" is not supported.")
        
        # Cache
        self.cache[l][expert_id] = expert_output
        return expert_output
    
    def postprocessing_hook(self, l: int):
        """Build a forward hook for layer ``l`` returning the alpha-weighted sum of all candidates.

        With ``update_type == "weight"`` the raw alphas are used as weights.
        They are presumably already softmaxed via ``softmax_arch_parameters``.
        With ``update_type == "alpha"`` a softmax is applied here so that
        gradients flow to the alphas.
        """

        def postprocess(module, inputs, output):
            layer_input = inputs[0]

            if self.update_type == "weight":
                # Don't do softmax
                alpha = self.alphas[l]
            elif self.update_type == "alpha":
                alpha = self.alphas[l].softmax(dim=-1)
            else:
                raise ValueError(f"update_type {self.update_type} not supported.")

            try:
                _out = [
                    self.expert_forward(l, expert_id, layer_input, output) * _alpha
                    for expert_id, _alpha in zip(self.nas_search_space[l], alpha)
                ]
            finally:
                # Always clear the per-pass cache, even if an expert raises,
                # so stale outputs are never reused on the next forward pass.
                self.cache[l].clear()

            return sum(_out)
        
        return postprocess
    
    def forward(self, x, update_type, pre_logits: bool=False, no_head_backbone=False, **kwargs):
        """Run the relaxed supernet.

        :param update_type: ``"weight"`` or ``"alpha"``; read by the postprocessing hooks.
        :param pre_logits: For ``task_idx > 0``, skip the task head and return features.
        :param no_head_backbone: If True, return ``self.backbone.forward(x)`` directly.
        :return: For task 0, the output of the backbone's own head. For later tasks, ``self.heads[self.task_idx]`` applied to the backbone's pre-logits (unless ``pre_logits``).
        """
        self.forward_state_args = kwargs
        self.update_type = update_type

        try:
            if no_head_backbone:
                return self.backbone.forward(x)
            x = self.backbone.forward_features(x)
            # Task 0 uses the backbone's own head; later tasks use self.heads.
            _pre_logits = self.task_idx != 0
            x = self.backbone.forward_head(x, pre_logits=_pre_logits)
        finally:
            # Reset the per-call kwargs even if the backbone raises.
            self.forward_state_args = None

        if self.task_idx > 0:
            x = x if pre_logits else self.heads[self.task_idx](x)
        return x

    def top_arch(self, return_alpha=False):
        """Return the highest-alpha expert id per layer (and optionally the alpha values)."""
        expert_ids = []
        top_alphas = []
        for nas_search_space, alphas in zip(self.nas_search_space, self.alphas):
            max_alpha = alphas.argmax().item()
            nas_expert_id = nas_search_space[max_alpha]
            expert_ids.append(nas_expert_id)
            top_alphas.append(alphas[max_alpha].item())
        if return_alpha:
            return expert_ids, top_alphas
        return expert_ids

                                                         

class MeanVarSimilaritySupernet(Supernet):
    """Supernet whose NAS sampling is guided by a mean/variance-based expert score.

    Each NAS-visible expert is scored from its task statistics ``task_mean``
    and ``task_var`` as the inverse of a coefficient-of-variation-like ratio
    (``var / |mean|``, averaged over channels). The per-layer score vectors,
    optionally rescaled to ``[-1, 1]``, are handed to one sampler per memory
    layer.
    """

    def __init__(self, 
                 task_idx: int,
                 backbone: nn.Module, 
                 config: dict, 
                 iter_backbone: Callable[[nn.Module], Iterable[Tuple[nn.Module, nn.Module]]],
                 head_factory: Callable[[nn.Module], nn.ModuleList],
                 op_factory_generator, # TODO: Typing
                 sampler,
                 stat_funcs: Dict[str, Callable[[torch.Tensor], torch.Tensor]]=None, 
                 stat_update_function: Callable[[torch.Tensor, torch.Tensor, Optional[Any]], torch.Tensor]=running_stat, 
                 similarity_stat=None,
                 initial_expert_ids: str="expert_0", 
                 normalize_similarities: bool=True,
                 **kwargs):
        """Supernet used with mean/variance-based NAS

        :param sampler: Sampler class, instantiated per memory layer as ``sampler(primitives, similarities, search_space)`` once the similarities are computed.
        :param stat_funcs: Dictionary of statistic name and computing function; must contain the keys ``"mean"`` and ``"var"``.
        :type stat_funcs: Dict[str, Callable[[torch.Tensor], torch.Tensor]]
        :param stat_update_function: Statistic update function across batches, defaults to running_stat
        :param similarity_stat: Accepted for signature compatibility; not used by this class.
        :param initial_expert_ids: IDs for the initial identity experts per layer, defaults to "expert_0"
        :param normalize_similarities: Normalize similarities across experts, defaults to True
        :raises AssertionError: If ``stat_funcs`` is None or lacks ``"mean"`` or ``"var"``.
        """
        # calculate_similarities reads the task_mean / task_var statistics, so
        # both must be requested. Validate before building the base supernet.
        assert stat_funcs is not None and {"mean", "var"}.issubset(stat_funcs.keys()), "MeanVarSimilaritySupernet requires mean and var statistics."

        # The base class is built without a sampler: samplers can only be
        # created once the similarities are known (see ``calculate_similarities``).
        super().__init__(task_idx=task_idx,
                         backbone=backbone,
                         config=config,
                         iter_backbone=iter_backbone,
                         head_factory=head_factory,
                         op_factory_generator=op_factory_generator,
                         sampler=None,
                         stat_funcs=stat_funcs,
                         stat_update_function=stat_update_function,
                         initial_expert_ids=initial_expert_ids,
                         **kwargs)

        self.normalize_similarities = normalize_similarities
        self.sampler = sampler
        self.similarities = None

        if stat_funcs is not None:
            self.init_task_statistics()

        self.seed = None

    def init_task_statistics(self):
        """Create a ``task_<stat>`` ModuleList for every statistic in ``self.stat_funcs``.

        Each list holds one ``Statistics`` module per statistics layer, with one
        buffer per expert of that layer. The entry is ``None`` where
        ``iter_backbone`` yields no layer.
        """
        for stat, stat_func in self.stat_funcs.items():
            stats = []
            for l, (layer_name, layer, dim) in enumerate(self.iter_backbone(self.backbone, mode="statistics")):
                layer_stats = None
                if layer is not None:
                    layer_stats = Statistics(list(self.experts[l].keys()), dim, stat_func, self.stat_update_function)
                stats.append(layer_stats)

            setattr(self, f"task_{stat}", nn.ModuleList(stats))
    
    def calculate_similarities(self):
        """Score every NAS-visible expert per layer and (re)build the samplers.

        The score is ``1 / (mean_c(var_c / (|mean_c| + eps)) + eps)``, where
        ``c`` runs over the channels of the ``Statistics`` buffers. This ratio
        uses the variance, not the standard deviation. The global min/max
        across layers are kept for normalization. Layers without statistics
        get ``None``.
        """
        with torch.no_grad():
            self._similarities = []
            # The score is non-negative and unbounded, so use infinite sentinels.
            min_similarity = float("inf")
            max_similarity = float("-inf")
            for l, (layer_name, layer, dim) in enumerate(self.iter_backbone(self.backbone, mode="statistics")):
                layer_similarities = None
                if layer is not None:
                    # Hide skip experts from NAS operation
                    layer_expert_list = self.nas_expert_list[l]
                    task_means = self.task_mean[l]
                    task_vars = self.task_var[l]
                    dtype = getattr(task_means, layer_expert_list[0]).dtype
                    layer_similarities = torch.zeros(len(layer_expert_list), dtype=dtype)
                    for i, expert_id in enumerate(layer_expert_list):
                        task_mean = getattr(task_means, expert_id)
                        task_var = getattr(task_vars, expert_id)
                        # The Statistics buffers hold one value per channel, so
                        # reduce the per-channel ratio to a single score per expert.
                        # NOTE(review): channel-mean is the chosen reduction; confirm intent.
                        coeff_of_var = (task_var / (torch.abs(task_mean) + 1e-8)).mean()
                        similarity = 1. / (coeff_of_var + 1e-8)
                        layer_similarities[i] = similarity
                        min_similarity = min(min_similarity, similarity.item())
                        max_similarity = max(max_similarity, similarity.item())
                self._similarities.append(layer_similarities)

            self.min_similarity = min_similarity
            self.max_similarity = max_similarity

            if self.normalize_similarities:
                self.calculate_normalized_similarities()
                self.similarities = self._normalized_similarities
            else:
                self.similarities = self._similarities

        self.init_samplers(self.sampler)

    def register_stat_hooks(self, task_idx: int):
        """Register forward hooks that accumulate task statistics for ``task_idx``.

        In each layer, a hook is registered only if the expert mapped to
        ``task_idx`` was originally learnt for that task. For task 0 the hook
        is put on the backbone layer itself; otherwise it goes on the expert
        op. The handles are stored in ``self.stat_hook_handles`` (one dict per
        layer) so that they can be removed later.
        """
        assert self._conditional_execution_hooks_registered, "Conditional execution hooks must be registered before registering the stat hooks."

        self.stat_hook_handles = []
        for l, (layer_name, layer, dim) in enumerate(self.iter_backbone(self.backbone, mode="statistics")):
            layer_stat_hooks = dict()
            # Register the hooks for the experts of the current task
            expert_id = self.task_to_expert_map[l][task_idx]
            expert_op = self.experts[l][expert_id]
            # Only register the hook if the expert was originally learnt for task_idx
            # This way, the stats for other experts will never be calculated
            if expert_op.associated_tasks[0] == task_idx and layer is not None:
                if task_idx == 0:
                    expert_stat_hook_handle = layer.register_forward_hook(self.stat_hook(l, expert_id))
                else:
                    expert_stat_hook_handle = expert_op.register_forward_hook(self.stat_hook(l, expert_id))
                layer_stat_hooks[expert_id] = expert_stat_hook_handle

            # Save the handles since we need to remove the hooks later
            self.stat_hook_handles.append(layer_stat_hooks)

    def stat_hook(self, l: int, expert_id: str):
        """Return a forward hook that feeds the module output to every ``task_<stat>[l]`` module."""

        def calculate_stats(module, input, output):
            with torch.no_grad():
                for stat in self.stat_funcs:
                    stat_module = getattr(self, f"task_{stat}")[l]
                    stat_module(output, expert_id)

        return calculate_stats

    def calculate_normalized_similarities(self):
        """Min-max rescale ``self._similarities`` to ``[-1, 1]`` using the global min/max.

        A small epsilon avoids division by zero when all scores are equal.
        Layers without scores stay ``None``.
        """
        with torch.no_grad():
            self._normalized_similarities = []
            for l, _similarities in enumerate(self._similarities):
                layer_similarities = None
                if _similarities is not None:
                    layer_similarities = (_similarities - self.min_similarity) / (self.max_similarity - self.min_similarity + 1e-8)
                    layer_similarities = 2. * layer_similarities - 1.
                self._normalized_similarities.append(layer_similarities)

    def init_samplers(self, sampler_class):
        """Instantiate one sampler per memory layer from the current similarities."""
        self.samplers = []
        for l, (layer_name, memory_layer, dim, primitives) in enumerate(self.iter_backbone(self.backbone, mode="memory")):
            layer_sampler = sampler_class(primitives, self.similarities[l], self.nas_search_space[l])
            self.samplers.append(layer_sampler)
