import torch
import torch.nn as nn

from layers_mtl import DiffractiveLayerRawNode


class mtl_model(nn.Module):
    """Base class for multi-task models built from policy-carrying nodes.

    A "node" is any submodule for which ``check_type`` returns True (by
    default a ``DiffractiveLayerRawNode``). Each node is expected to expose a
    ``policy`` mapping from task key to a 1-D logit tensor.

    NOTE(review): this class reads ``self.taskList``, ``self.max_policy_idx()``
    and ``self.current_policy_reg(...)`` without defining them; presumably
    subclasses provide them — confirm against the subclasses.
    """

    def __init__(self):
        """Initialize the underlying ``nn.Module``; node depths are not set here.

        Call ``compute_depth()`` before ``max_node_depth()`` or ``policy_reg()``,
        which read each node's ``depth`` attribute.
        """
        super(mtl_model, self).__init__()

    def compute_depth(self):
        """Assign ``depth`` = 0, 1, 2, ... to each node in ``self.modules()`` order."""
        cur_dep = 0
        for module in self.modules():
            if self.check_type(module):
                module.depth = cur_dep
                cur_dep += 1

    def share_bottom_policy(self, share_num):
        """Freeze the policy of the first ``share_num`` nodes to "option 0".

        For each of the first ``share_num`` nodes (in ``self.modules()`` order)
        and every task in ``self.taskList``, the task's policy is replaced by a
        non-trainable one-hot tensor selecting option 0.

        The one-hot tensor has the same length, dtype and device as the node's
        existing policy for that task. This keeps it compatible with
        ``policy_reg`` and with models moved to a GPU. When the node has no
        policy for the task yet, ``[1., 0., 0.]`` is used, as before.

        Args:
            share_num: number of leading nodes whose policy is frozen.
        """
        count = 0
        for node in self.modules():
            if not self.check_type(node):
                continue
            if count == share_num:
                break
            count += 1
            for task in self.taskList:
                if task in node.policy:
                    ref = node.policy[task]
                    fixed = torch.zeros(ref.shape, dtype=ref.dtype, device=ref.device)
                    fixed[0] = 1.
                else:
                    fixed = torch.tensor([1., 0., 0.])
                node.policy[task] = nn.Parameter(fixed, requires_grad=False)

    def max_node_depth(self):
        """Return the largest ``depth`` among all nodes (0 if there are none)."""
        max_depth = 0
        for module in self.modules():
            if self.check_type(module):
                max_depth = max(module.depth, max_depth)
        return max_depth

    def check_type(self, module):
        """Return True if ``module`` is a policy node (a ``DiffractiveLayerRawNode``)."""
        return isinstance(module, DiffractiveLayerRawNode)

    def policy_reg(self, task, policy_idx=None, tau=5, scale=1):
        """Return a regularization term on the per-task node policies.

        When ``policy_idx`` is None, each node's policy logits for ``task`` are
        turned into a soft Gumbel-softmax sample ``p``. The node then
        contributes::

            (D - depth) / D * sum_{k >= 1} softplus(scale * (p[k] - p[0]))

        Here ``D`` is ``max_node_depth()``. Options ``k >= 1`` that get more
        probability than option 0 are therefore penalized, and shallower nodes
        are weighted more heavily (the deepest node gets weight 0).

        Args:
            task: key of the task whose policy is regularized.
            policy_idx: None to regularize every node. Otherwise the work is
                delegated to ``self.current_policy_reg`` when
                ``policy_idx < self.max_policy_idx()``.
            tau: Gumbel-softmax temperature.
            scale: multiplier applied to the probability differences.

        Returns:
            A 0-dim tensor; ``0.`` when no term applies.
        """
        reg = torch.tensor(0.)
        if policy_idx is None:
            # Regularization for all policies.
            # Hoisted: max_node_depth() walks the whole module tree.
            max_depth = self.max_node_depth()
            # With max depth 0 the weight would be 0/0. The deepest node always
            # gets weight 0 under this formula, so use 1 as the denominator.
            denom = max_depth if max_depth > 0 else 1
            for module in self.modules():
                if not self.check_type(module):
                    continue
                possibility = nn.functional.gumbel_softmax(
                    module.policy[task], tau=tau, hard=False)
                # softplus(x) == log(1 + exp(x)), without overflowing for large x.
                # Summing over every option >= 1 handles policies of any length
                # (e.g. the 3-option policies installed by share_bottom_policy).
                loss = nn.functional.softplus(
                    scale * (possibility[1:] - possibility[0])).sum()
                weight = (max_depth - module.depth) / denom
                reg = reg + weight * loss
        elif policy_idx < self.max_policy_idx():
            # Regularization for the currently trained policy.
            reg = self.current_policy_reg(policy_idx, task, tau, scale)
        return reg
