import torch.nn as nn


# Names of the primitives offered by default. "identity" is deliberately not
# listed here although it has an entry in the cost table below.
DEFAULT_PRIMITIVES = [
    "skip",
    "reuse",
    "adapt",
    "new",
]

# Per-primitive cost values; the factories in PRIMITIVES use them as the
# default `size` argument of the corresponding primitive.
DEFAULT_PRIMITIVE_COSTS = dict(
    skip=-2,
    reuse=0,
    adapt=1,
    new=2,
    identity=0,
)


class Primitive(nn.Module):
    """Base class shared by the primitive experts defined in this module.

    The base class only holds bookkeeping state; subclasses set ``primitive``
    and ``size`` and provide the actual ``forward`` implementation.

    Attributes:
        primitive: Name of the primitive; ``None`` on the bare base class.
        size: Size/cost value attached to the primitive; ``None`` on the
            bare base class.
        associated_tasks: List of tasks registered through
            ``associate_task`` / ``associate_tasks``.
    """

    def __init__(self):
        super().__init__()

        self.primitive = None
        self.size = None
        self.associated_tasks = []

    @property
    def config(self):
        """Return a new dict with the primitive name, size and associated tasks.

        The ``"associated_tasks"`` entry is the live list held by the
        instance, not a copy, so later task registrations are visible
        through a previously returned config.
        """
        return {
            "primitive": self.primitive,
            "size": self.size,
            "associated_tasks": self.associated_tasks
        }

    def forward(self, *args, **kwargs):
        """Abstract forward pass; every concrete primitive must override it.

        Arbitrary arguments are accepted so that invoking a primitive that
        forgot to override ``forward`` fails with ``NotImplementedError``
        instead of a misleading ``TypeError`` about the argument count.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def associate_task(self, task):
        """Append a single ``task`` to ``associated_tasks``."""
        self.associated_tasks.append(task)

    def associate_tasks(self, tasks):
        """Append every element of the list ``tasks`` to ``associated_tasks``.

        ``tasks`` must be a ``list``; this is checked with an ``assert``.
        """
        assert isinstance(tasks, list)
        self.associated_tasks += tasks
            

# Special expert for backbone
class Identity(Primitive):
    """Pass-through primitive, labelled ``"identity"``, that wraps no operation.

    Args:
        size: Value stored as ``self.size``.
    """

    def __init__(self, size):
        super().__init__()

        self.size = size
        self.op = None  # there is no wrapped operation for this primitive
        self.primitive = "identity"

    def forward(self, x, **kwargs):
        """Return ``x`` untouched; keyword arguments are accepted and ignored."""
        return x


class Skip(Primitive):
    """Primitive labelled ``"skip"`` whose output is detached from autograd.

    Args:
        backbone_layer: First positional argument handed to ``op_factory``.
        op_factory: Callable invoked as ``op_factory(backbone_layer, **kwargs)``;
            its result is stored as ``self.op``.
        size: Value stored as ``self.size``.
        **kwargs: Forwarded to ``op_factory``.
    """

    def __init__(self, backbone_layer, op_factory, size, **kwargs):
        super().__init__()

        built_op = op_factory(backbone_layer, **kwargs)

        self.primitive = "skip"
        self.size = size
        self.op = built_op

    def forward(self, x, **kwargs):
        """Apply ``self.op`` to ``x`` and return the detached result.

        Keyword arguments are accepted but not passed on to ``self.op``.
        """
        output = self.op(x)
        # Detaching stops gradients from flowing back through this primitive.
        return output.detach()


class Reuse(Primitive):
    """Primitive labelled ``"reuse"``; its forward pass is the identity.

    Args:
        parent_expert_id: Identifier stored on the instance and reported in
            ``config`` under ``"parent_expert_id"``.
        size: Value stored as ``self.size``.
    """

    def __init__(self, parent_expert_id, size):
        super().__init__()

        self.op = None  # nothing is wrapped by this primitive
        self.size = size
        self.parent_expert_id = parent_expert_id
        self.primitive = "reuse"

    def forward(self, x, **kwargs):
        """Return ``x`` as is; keyword arguments are accepted and ignored."""
        return x

    @property
    def config(self):
        """Base config extended with the ``"parent_expert_id"`` entry."""
        return {**super().config, "parent_expert_id": self.parent_expert_id}


class Adapt(Primitive):
    """Primitive labelled ``"adapt"`` that wraps an op built by ``op_factory``.

    Args:
        parent_expert_id: Identifier stored on the instance and reported in
            ``config`` under ``"parent_expert_id"``.
        backbone_layer: First positional argument handed to ``op_factory``.
        op_factory: Callable invoked as ``op_factory(backbone_layer, **kwargs)``;
            its result is stored as ``self.op``.
        size: Value stored as ``self.size``.
        **kwargs: Forwarded to ``op_factory``.
    """

    def __init__(self, parent_expert_id, backbone_layer, op_factory, size, **kwargs):
        super().__init__()

        built_op = op_factory(backbone_layer, **kwargs)

        self.primitive = "adapt"
        self.parent_expert_id = parent_expert_id
        self.size = size
        self.op = built_op

    def forward(self, parent_input, parent_output, **kwargs):
        """Call ``self.op`` with the parent's input, its output and any kwargs."""
        result = self.op(parent_input, parent_output, **kwargs)
        return result

    @property
    def config(self):
        """Base config extended with the ``"parent_expert_id"`` entry."""
        return {**super().config, "parent_expert_id": self.parent_expert_id}


class New(Primitive):
    """Primitive labelled ``"new"`` that wraps an op built by ``op_factory``.

    Args:
        backbone_layer: First positional argument handed to ``op_factory``.
        op_factory: Callable invoked as ``op_factory(backbone_layer, **kwargs)``;
            its result is stored as ``self.op``.
        size: Value stored as ``self.size``.
        **kwargs: Forwarded to ``op_factory``.
    """

    def __init__(self, backbone_layer, op_factory, size, **kwargs):
        super().__init__()

        built_op = op_factory(backbone_layer, **kwargs)

        self.primitive = "new"
        self.size = size
        self.op = built_op

    def forward(self, x, **kwargs):
        """Apply ``self.op`` to ``x``, passing the keyword arguments along."""
        result = self.op(x, **kwargs)
        return result


# Factory functions behind the PRIMITIVES registry. They all share the calling
# convention (parent_expert_id, backbone_layer, op_factory, size, **kwargs) so
# that any primitive can be built the same way; each factory passes on only
# the arguments its primitive takes.

def _build_skip(parent_expert_id, backbone_layer, op_factory,
                size=DEFAULT_PRIMITIVE_COSTS["skip"], **kwargs):
    # parent_expert_id and the extra kwargs are accepted but not passed on.
    # NOTE(review): the "adapt" and "new" factories do forward **kwargs to the
    # op factory; confirm that dropping them for "skip" is intended.
    return Skip(backbone_layer, op_factory, size)


def _build_reuse(parent_expert_id, backbone_layer=None, op_factory=None,
                 size=DEFAULT_PRIMITIVE_COSTS["reuse"], **kwargs):
    # Only the parent id and the size are used.
    return Reuse(parent_expert_id, size)


def _build_adapt(parent_expert_id, backbone_layer, op_factory,
                 size=DEFAULT_PRIMITIVE_COSTS["adapt"], **kwargs):
    return Adapt(parent_expert_id, backbone_layer, op_factory, size, **kwargs)


def _build_new(parent_expert_id, backbone_layer, op_factory,
               size=DEFAULT_PRIMITIVE_COSTS["new"], **kwargs):
    # parent_expert_id is accepted but not used.
    return New(backbone_layer, op_factory, size, **kwargs)


def _build_identity(parent_expert_id, backbone_layer=None, op_factory=None,
                    size=DEFAULT_PRIMITIVE_COSTS["identity"], **kwargs):
    # Only the size is used.
    return Identity(size)


# Registry mapping a primitive name to the factory that builds it.
PRIMITIVES = {
    "skip": _build_skip,
    "reuse": _build_reuse,
    "adapt": _build_adapt,
    "new": _build_new,
    "identity": _build_identity,
}
