from pathlib import Path
import torch
import torch.nn as nn

from collections import OrderedDict

from src import modules
from src import modules_fw
from src import modules_fw_v2
from src import utils
from src import resources


# Output-channel count for each supported task name. The value is looked up
# by task and passed as `num_outputs` when the per-task decoder is built.
TASK_CHANNEL_MAPPING = {
    'semseg': 21,
    'human_parts': 7,
    'sal': 1,
    'normals': 3,
    'edge': 1
}


# Checkpoint fetched by `_load_imagenet_weights` (MobileNetV2 weights published
# for torchvision.models) to initialise the stem and encoder blocks.
MODEL_URL = 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'

def is_father_str(str, str_list):
    """Tell whether ``str`` contains at least one entry of ``str_list``.

    Args:
        str: The string that is searched (note: the parameter name shadows
            the builtin inside this function; it is kept for compatibility
            with keyword callers).
        str_list: Iterable of candidate substrings.

    Returns:
        True if any candidate occurs inside ``str``; False otherwise,
        including when ``str_list`` is empty.
    """
    return any(candidate in str for candidate in str_list)

class MoblieNetV2(nn.Module):
    """Multi-task network with a shared backbone and one decoder per task.

    The model is assembled from a ``modules.ConvBNReLU`` stem, an encoder made
    of 17 ``modules.InvertedResidual`` blocks followed by ``modules.RASPP``,
    and one ``modules.DeepLabV3PlusDecoder`` per entry of ``tasks``.

    Besides the regular ``nn.Module`` bookkeeping, the instance keeps two
    name -> parameter maps that are rebuilt by :meth:`turn`:

    * ``_shared_parameters``: parameters of modules that are not task specific.
    * ``_task_specific_parameters``: parameters of the decoder modules and of
      every module named in ``task_branches``.

    Args:
        tasks: Sequence of task names; each must be a key of
            ``TASK_CHANNEL_MAPPING``.
        pretrain: When true, stem and encoder blocks 0-16 are initialised
            from the checkpoint at ``MODEL_URL``.
    """

    # Arguments of the 17 inverted-residual encoder blocks, in order:
    # (in_channels, out_channels, 3rd positional, 4th positional, dilation).
    # The 3rd/4th values are forwarded positionally; presumably stride and
    # expansion factor -- confirm against modules.InvertedResidual.
    _ENCODER_BLOCK_CFG = (
        (32, 16, 1, 1, 1),
        (16, 24, 2, 6, 1),
        (24, 24, 1, 6, 1),
        (24, 32, 2, 6, 1),
        (32, 32, 1, 6, 1),
        (32, 32, 1, 6, 1),
        (32, 64, 2, 6, 1),
        (64, 64, 1, 6, 1),
        (64, 64, 1, 6, 1),
        (64, 64, 1, 6, 1),
        (64, 96, 1, 6, 1),
        (96, 96, 1, 6, 1),
        (96, 96, 1, 6, 1),
        (96, 160, 2, 6, 2),
        (160, 160, 1, 6, 2),
        (160, 160, 1, 6, 2),
        (160, 320, 1, 6, 2),
    )

    def __init__(self, tasks, pretrain=True):

        super().__init__()
        self.tasks = tasks
        self.n_tasks = len(tasks)
        self._shared_parameters = OrderedDict()
        self._task_specific_parameters = OrderedDict()

        self.stem = modules.ConvBNReLU(in_channels=3,
                                       out_channels=32,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       activation='relu6')

        blocks = [
            modules.InvertedResidual(
                c_in, c_out, third, fourth,
                dilation=dilation, activation='relu6')
            for c_in, c_out, third, fourth, dilation in self._ENCODER_BLOCK_CFG
        ]
        # Index 17 of the encoder; it is skipped when loading ImageNet weights.
        blocks.append(
            modules.RASPP(320, 128, activation='relu6', drop_rate=0.1))

        self.encoder = nn.Sequential(*blocks)

        self.decoder = nn.ModuleDict({
            task: modules.DeepLabV3PlusDecoder(
                in_channels_low=24,
                out_f_classifier=128,
                use_separable=True,
                activation='relu6',
                num_outputs=TASK_CHANNEL_MAPPING[task])
            for task in self.tasks})

        # Order matters: td_layers must exist before turn(), and the random
        # initialisation must happen before pretrained weights overwrite it.
        self._obtian_td_layers()
        self._initialize_weights()

        if pretrain:
            self._load_imagenet_weights()

        self.turn()

    @staticmethod
    def _own_parameters(prefix, module):
        """Yield ``(qualified_name, parameter)`` for a module's direct parameters.

        Only entries of ``module._parameters`` are visited (no recursion).
        ``None`` entries and repeated parameter objects are skipped.
        """
        seen = set()
        for key, param in module._parameters.items():
            if param is None or param in seen:
                continue
            seen.add(param)
            yield prefix + ('.' if prefix else '') + key, param

    def turn(self, task_branches=None):
        """Rebuild the shared / task-specific parameter maps.

        Args:
            task_branches: Optional iterable of module names (as produced by
                ``named_modules()``) that should be treated as task specific.
                ``None`` (the default) means no extra branches.

        When the model has more than one task, ``set_n_tasks`` is called on
        every module named in ``task_branches`` (that method is defined
        outside this file; presumably it creates the per-task copies stored
        under ``m_list`` -- confirm). Afterwards every module's own
        parameters are filed under ``_task_specific_parameters`` if the module
        is in ``task_branches`` or ``self.td_layers``, and under
        ``_shared_parameters`` otherwise.

        Both maps are cleared first, so after the call they reflect only the
        current ``task_branches``; entries from an earlier call never linger.
        """
        requested = [] if task_branches is None else list(task_branches)
        requested_names = set(requested)

        self._shared_parameters.clear()
        # Previously only the shared map was reset, which left stale entries
        # here when turn() was called again with different branches.
        self._task_specific_parameters.clear()

        if self.n_tasks > 1:
            for idx, m in self.named_modules():
                if idx in requested_names:
                    m.set_n_tasks(n_tasks=self.n_tasks)

        task_specific_names = requested_names.union(self.td_layers)

        for idx, m in self.named_modules():
            # Sub-modules living under '<name>.m_list...' are attributed to
            # '<name>' itself.
            idx = idx.partition('.m_list')[0]

            if idx in task_specific_names:
                target = self._task_specific_parameters
            else:
                target = self._shared_parameters

            for name, param in self._own_parameters(idx, m):
                target[name] = param

    def shared_parameters(self):
        """Return the name -> parameter map of shared (non task-specific) weights."""
        return self._shared_parameters

    def model_size(self):
        """Return the total size of all parameters in mebibytes (bytes / 1024**2)."""
        param_size = sum(
            param.nelement() * param.element_size()
            for param in self.parameters())
        return param_size / 1024 ** 2

    def forward(self, x):
        """Run the backbone once and every task decoder on its output.

        Args:
            x: Input tensor; its last two dimensions are forwarded to the
                decoders as the input shape.

        Returns:
            dict mapping each task name to that task's decoder output.
        """
        input_shape = x.shape[-2:]
        out = self.stem(x)

        for i, layer in enumerate(self.encoder):
            out = layer(out)
            if i == 2:
                # Output of encoder block 2 is kept as the low-level feature
                # handed to the decoders alongside the final encoder output.
                out_low = torch.clone(out)

        output = {task: self.decoder[task](out,
                                           out_low,
                                           input_shape)
                  for task in self.tasks}

        return output

    def _initialize_weights(self):
        """Initialise every Conv2d and BatchNorm2d sub-module in place.

        Conv2d weights use Kaiming-normal (``fan_out``) initialisation, except
        for modules whose name contains ``'logits'``, which are drawn from
        N(0, 0.01). Conv biases are zeroed. Affine BatchNorm2d layers get
        weight 1 and bias 0.
        """
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'logits' in name:
                    # final prediction layer: fixed, small std
                    nn.init.normal_(m.weight, mean=0, std=0.01)
                else:
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                if m.affine:
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)

    def _obtian_td_layers(self):
        """Collect the names of task-dependent modules into ``self.td_layers``.

        A module is task dependent when it directly owns at least one
        parameter whose qualified name contains ``'decoder'``. Names are
        stored once each, in ``named_modules()`` traversal order.
        """
        task_dependent_modules_names = ['decoder']

        self.td_layers = [
            idx
            for idx, m in self.named_modules()
            if any(is_father_str(name, task_dependent_modules_names)
                   for name, _ in self._own_parameters(idx, m))
        ]

    @staticmethod
    def _imagenet_source_name(name_trg):
        """Map one of this model's state-dict keys to the checkpoint's key.

        Args:
            name_trg: Key from ``self.state_dict()``.

        Returns:
            The corresponding key in the downloaded state dict, or ``None``
            for keys that are not loaded (decoders and encoder block 17).

        Raises:
            ValueError: If the key belongs to neither stem, encoder nor
                decoder.
        """
        if name_trg.startswith(('decoder', 'encoder.17')):
            return None  # can't load decoder and ASPP
        if name_trg.startswith('stem'):
            return name_trg.replace('stem.op', 'features.0')
        if not name_trg.startswith('encoder'):
            raise ValueError(
                'cannot map parameter %r to a pretrained weight' % (name_trg,))

        # This is highly specific to the current naming:
        #   encoder.<layer>.chain.<chain>.op.<op>.<param>
        parsed = name_trg.split('.')
        layer_nr = int(parsed[1])
        chain_nr = int(parsed[3])
        op_nr = int(parsed[5])
        parsed[0] = 'features'
        parsed[1] = str(layer_nr + 1)  # features.0 is the stem
        parsed[2] = 'conv'
        if chain_nr == 0 or (chain_nr == 1 and layer_nr != 0):
            # keep the nested '<chain>.<op>' index, drop only the 'op' token
            parsed[3] = str(chain_nr)
            del parsed[4]
        else:
            # the last chain is flattened in the source: index = chain + op
            parsed[3] = str(chain_nr + op_nr)
            del parsed[4:6]
        return '.'.join(parsed)

    def _load_imagenet_weights(self):
        """Download the checkpoint at ``MODEL_URL`` and load stem/encoder weights.

        Decoder and encoder-block-17 entries keep their current values
        (``load_state_dict`` is called with ``strict=False``).

        Raises:
            ValueError: If a state-dict key cannot be mapped.
            KeyError: If a mapped key is missing from the checkpoint.
        """
        # we are using pretrained weights from torchvision.models:
        source_state_dict = torch.hub.load_state_dict_from_url(
            MODEL_URL, progress=False)

        mapped_state_dict = {}
        for name_trg in self.state_dict():
            name_src = self._imagenet_source_name(name_trg)
            if name_src is None:
                continue
            mapped_state_dict[name_trg] = source_state_dict[name_src]
        self.load_state_dict(mapped_state_dict, strict=False)

class MoblieNetV2_fw(nn.Module):
    def __init__(self, tasks, topK=0, branched='empty', pretrain=True):

        super().__init__()
        self.tasks = tasks
        self.n_tasks = len(tasks)
        self._shared_parameters = OrderedDict()
        self._task_specific_parameters = OrderedDict()

        self.stem = modules_fw.ConvBNReLU_fw(in_channels=3,
                                       out_channels=32,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       activation='relu6')

        blocks = [
            modules_fw.InvertedResidual_fw(
                32, 16, 1, 1, dilation=1, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                16, 24, 2, 6, dilation=1, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                24, 24, 1, 6, dilation=1, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                24, 32, 2, 6, dilation=1, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                32, 32, 1, 6, dilation=1, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                32, 32, 1, 6, dilation=1, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                32, 64, 2, 6, dilation=1, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                64, 64, 1, 6, dilation=1, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                64, 64, 1, 6, dilation=1, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                64, 64, 1, 6, dilation=1, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                64, 96, 1, 6, dilation=1, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                96, 96, 1, 6, dilation=1, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                96, 96, 1, 6, dilation=1, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                96, 160, 2, 6, dilation=2, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                160, 160, 1, 6, dilation=2, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                160, 160, 1, 6, dilation=2, activation='relu6'),
            modules_fw.InvertedResidual_fw(
                160, 320, 1, 6, dilation=2, activation='relu6'),
            modules_fw.RASPP_fw(320, 128, activation='relu6', drop_rate=0.1)
        ]

        self.encoder = nn.Sequential(*blocks)

        self.decoder = nn.ModuleDict({
            task: modules.DeepLabV3PlusDecoder(
                in_channels_low=24,
                out_f_classifier=128,
                use_separable=True,
                activation='relu6',
                num_outputs=TASK_CHANNEL_MAPPING[task])
            for task in self.tasks})

        if branched == 'empty':
            task_branches = []
        elif branched == 'branch':
            # task_branches = ['encoder.17.aspp_branch_2.1.op.0', 'encoder.17.aspp_branch_2.1.op.1', 'encoder.17.aspp_projection.op.1', 'encoder.16.chain.2.op.1',
            #                  'encoder.0.chain.1.op.1', 'encoder.0.chain.0.op.0', 'stem.op.1', 'encoder.16.chain.0.op.1',
            #                  'encoder.15.chain.2.op.1', 'encoder.2.chain.2.op.1', 'stem.op.0', 'encoder.0.chain.0.op.1',
            #                  'encoder.1.chain.2.op.1', 'encoder.16.chain.1.op.1', 'encoder.16.chain.2.op.0', 'encoder.2.chain.0.op.1',
            #                  'encoder.5.chain.2.op.1', 'encoder.1.chain.0.op.1', 'encoder.3.chain.2.op.1', 'encoder.14.chain.2.op.1',
            #                  'encoder.17.aspp_branch_1.op.1', 'encoder.4.chain.2.op.1', 'encoder.9.chain.2.op.1', 'encoder.17.aspp_projection.op.0',
            #                  'encoder.13.chain.2.op.1', 'encoder.12.chain.2.op.1', 'encoder.15.chain.1.op.1', 'encoder.8.chain.2.op.1',
            #                  'encoder.1.chain.1.op.1', 'encoder.15.chain.0.op.1', 'encoder.16.chain.0.op.0', 'encoder.7.chain.2.op.1',
            #                  'encoder.11.chain.2.op.1', 'encoder.10.chain.2.op.1', 'encoder.6.chain.2.op.1', 'encoder.2.chain.1.op.1',
            #                  'encoder.3.chain.0.op.1', 'encoder.14.chain.1.op.0', 'encoder.13.chain.0.op.1', 'encoder.0.chain.1.op.0']

            # task_branches = ['encoder.17.aspp_branch_2.1.op.0', 'encoder.17.aspp_branch_2.1.op.1',
            #                  'encoder.17.aspp_projection.op.1', 'encoder.16.chain.2.op.1',
            #                  'encoder.0.chain.1.op.1', 'encoder.0.chain.0.op.0', 'stem.op.1', 'encoder.16.chain.0.op.1',
            #                  'encoder.15.chain.2.op.1', 'encoder.2.chain.2.op.1', 'stem.op.0', 'encoder.0.chain.0.op.1',
            #                  'encoder.1.chain.2.op.1', 'encoder.16.chain.1.op.1', 'encoder.16.chain.2.op.0',
            #                  'encoder.2.chain.0.op.1',
            #                  'encoder.5.chain.2.op.1', 'encoder.1.chain.0.op.1', 'encoder.3.chain.2.op.1',
            #                  'encoder.14.chain.2.op.1',
            #                  'encoder.17.aspp_branch_1.op.1', 'encoder.4.chain.2.op.1', 'encoder.9.chain.2.op.1',
            #                  'encoder.17.aspp_projection.op.0',
            #                  'encoder.13.chain.2.op.1', 'encoder.12.chain.2.op.1', 'encoder.15.chain.1.op.1',
            #                  'encoder.8.chain.2.op.1',
            #                  'encoder.1.chain.1.op.1', 'encoder.15.chain.0.op.1', 'encoder.16.chain.0.op.0',
            #                  'encoder.7.chain.2.op.1',
            #                  'encoder.11.chain.2.op.1', 'encoder.10.chain.2.op.1', 'encoder.6.chain.2.op.1',
            #                  'encoder.2.chain.1.op.1',
            #                  'encoder.3.chain.0.op.1', 'encoder.14.chain.1.op.0', 'encoder.13.chain.0.op.1',
            #                  'encoder.0.chain.1.op.0']
            # for b in task_branches:
            #     for name, param in self.named_parameters():
            #         if b in name:
            #             param_size = param.nelement() * param.element_size()
            #             print(f'{b}: {param_size}')


            # lw_cos -0.1
            # task_branches = ['encoder.0.chain.0.op.0', 'encoder.0.chain.1.op.1', 'stem.op.1', 'encoder.0.chain.0.op.1', 'stem.op.0',
            #                  'encoder.2.chain.2.op.1', 'encoder.1.chain.2.op.1', 'encoder.1.chain.0.op.1', 'encoder.5.chain.2.op.1', 'encoder.3.chain.2.op.1',
            #                  'encoder.4.chain.2.op.1', 'encoder.2.chain.0.op.1', 'encoder.9.chain.2.op.1', 'encoder.3.chain.0.op.1', 'encoder.1.chain.1.op.1',
            #                  'encoder.8.chain.2.op.1', 'encoder.0.chain.1.op.0', 'encoder.7.chain.2.op.1', 'encoder.6.chain.2.op.1', 'encoder.2.chain.1.op.1',
            #                  'encoder.12.chain.2.op.1', 'encoder.3.chain.1.op.1', 'encoder.4.chain.0.op.1', 'encoder.2.chain.1.op.0', 'encoder.5.chain.0.op.1',
            #                  'encoder.1.chain.1.op.0', 'encoder.12.chain.1.op.1', 'encoder.11.chain.2.op.1', 'encoder.10.chain.2.op.1', 'encoder.6.chain.0.op.1',
            #                  'encoder.5.chain.1.op.1', 'encoder.4.chain.1.op.1', 'encoder.4.chain.1.op.0', 'encoder.6.chain.1.op.1', 'encoder.5.chain.1.op.0',
            #                  'encoder.15.chain.2.op.1', 'encoder.7.chain.1.op.0', 'encoder.11.chain.1.op.1', 'encoder.1.chain.0.op.0', 'encoder.7.chain.0.op.1',
            #                  'encoder.13.chain.2.op.1', 'encoder.12.chain.1.op.0', 'encoder.14.chain.2.op.1', 'encoder.9.chain.0.op.1', 'encoder.10.chain.0.op.1',
            #                  'encoder.8.chain.1.op.1', 'encoder.8.chain.0.op.1', 'encoder.14.chain.1.op.0', 'encoder.9.chain.1.op.1', 'encoder.12.chain.0.op.1',
            #                  'encoder.8.chain.1.op.0', 'encoder.11.chain.0.op.1', 'encoder.11.chain.1.op.0', 'encoder.1.chain.2.op.0', 'encoder.10.chain.1.op.0',
            #                  'encoder.9.chain.1.op.0', 'encoder.17.aspp_branch_2.1.op.1', 'encoder.7.chain.1.op.1', 'encoder.17.aspp_projection.op.1', 'encoder.10.chain.1.op.1',
            #                  'encoder.13.chain.0.op.1', 'encoder.3.chain.2.op.0', 'encoder.2.chain.2.op.0', 'encoder.17.aspp_branch_1.op.1', 'encoder.3.chain.0.op.0',
            #                  'encoder.15.chain.1.op.1', 'encoder.14.chain.1.op.1', 'encoder.3.chain.1.op.0', 'encoder.15.chain.1.op.0', 'encoder.6.chain.1.op.0',
            #                  'encoder.16.chain.2.op.1', 'encoder.14.chain.0.op.1', 'encoder.4.chain.0.op.0', 'encoder.13.chain.1.op.1', 'encoder.6.chain.2.op.0',
            #                  'encoder.15.chain.0.op.1', 'encoder.4.chain.2.op.0', 'encoder.6.chain.0.op.0', 'encoder.5.chain.2.op.0', 'encoder.13.chain.1.op.0',
            #                  'encoder.5.chain.0.op.0', 'encoder.2.chain.0.op.0', 'encoder.17.aspp_branch_2.1.op.0'
            #                  ]

            # lw_cos Adam -0.1
            task_branches = [
'encoder.0.chain.0.op.0', 'stem.op.1',
'stem.op.0',
'encoder.0.chain.1.op.1',
'encoder.2.chain.0.op.1',
'encoder.0.chain.0.op.1',
'encoder.2.chain.2.op.1',
'encoder.1.chain.2.op.1',
'encoder.1.chain.0.op.1',
'encoder.5.chain.2.op.1',
'encoder.4.chain.2.op.1',
'encoder.3.chain.2.op.1',
'encoder.3.chain.0.op.1',
'encoder.0.chain.1.op.0',
'encoder.1.chain.1.op.0',
'encoder.1.chain.1.op.1',
'encoder.9.chain.2.op.1',
'encoder.2.chain.1.op.0',
'encoder.2.chain.1.op.1',
'encoder.4.chain.0.op.1',
'encoder.8.chain.2.op.1',
'encoder.7.chain.2.op.1',
'encoder.6.chain.2.op.1',
'encoder.3.chain.1.op.1',
'encoder.5.chain.0.op.1',
'encoder.6.chain.0.op.1',
'encoder.12.chain.2.op.1',
'encoder.12.chain.1.op.1',
'encoder.4.chain.1.op.1',
'encoder.4.chain.1.op.0',
'encoder.5.chain.1.op.1',
'encoder.8.chain.0.op.1',
'encoder.10.chain.2.op.1',
'encoder.11.chain.2.op.1',
'encoder.1.chain.0.op.0',
'encoder.17.aspp_branch_2.1.op.0',
'encoder.5.chain.1.op.0',
'encoder.6.chain.1.op.1',
'encoder.7.chain.0.op.1',
'encoder.12.chain.1.op.0',
'encoder.7.chain.1.op.0',
'encoder.12.chain.0.op.1',
'encoder.9.chain.0.op.1',
'encoder.15.chain.2.op.1',
'encoder.11.chain.0.op.1',
'encoder.11.chain.1.op.1',
'encoder.13.chain.2.op.1',
'encoder.1.chain.2.op.0',
'encoder.8.chain.1.op.0',
'encoder.14.chain.2.op.1',
'encoder.14.chain.1.op.0',
'encoder.8.chain.1.op.1',
'encoder.3.chain.0.op.0',
'encoder.10.chain.0.op.1',
'encoder.11.chain.1.op.0',
'encoder.3.chain.2.op.0',
'encoder.9.chain.1.op.1',
'encoder.2.chain.2.op.0',
'encoder.9.chain.1.op.0',
'encoder.13.chain.0.op.1',
'encoder.17.aspp_branch_2.1.op.1',
'encoder.7.chain.1.op.1',
'encoder.3.chain.1.op.0',
'encoder.4.chain.0.op.0',
'encoder.10.chain.1.op.1',
'encoder.10.chain.1.op.0',
'encoder.17.aspp_projection.op.1',
'encoder.14.chain.1.op.1',
'encoder.15.chain.1.op.1',
'encoder.14.chain.0.op.1',
'encoder.15.chain.1.op.0',
'encoder.16.chain.2.op.1',
'encoder.4.chain.2.op.0',
'encoder.6.chain.1.op.0',
'encoder.5.chain.2.op.0',
'encoder.6.chain.0.op.0',
'encoder.6.chain.2.op.0',
'encoder.17.aspp_branch_1.op.1',
'encoder.13.chain.1.op.1',
'encoder.2.chain.0.op.0',
'encoder.5.chain.0.op.0',
'encoder.15.chain.0.op.1',
'encoder.16.chain.0.op.1',
'encoder.13.chain.1.op.0',
'encoder.9.chain.2.op.0',
'encoder.8.chain.2.op.0',
'encoder.10.chain.2.op.0',
'encoder.12.chain.2.op.0',
'encoder.11.chain.2.op.0',
'encoder.10.chain.0.op.0',
'encoder.7.chain.2.op.0',
'encoder.8.chain.0.op.0',
'encoder.16.chain.1.op.0',
'encoder.7.chain.0.op.0',
'encoder.12.chain.0.op.0',
'encoder.11.chain.0.op.0',
'encoder.9.chain.0.op.0',
'encoder.13.chain.0.op.0',
'encoder.16.chain.1.op.1',
]

            # lw_cos Adam cagrad 40 epoch -0.1
            # task_branches = [
            #     'encoder.0.chain.0.op.0',
            #     'encoder.0.chain.1.op.1',
            #     'stem.op.1',
            #     'stem.op.0',
            #     'encoder.2.chain.0.op.1',
            #     'encoder.0.chain.0.op.1',
            #     'encoder.2.chain.2.op.1',
            #     'encoder.1.chain.0.op.1',
            #     'encoder.1.chain.2.op.1',
            #     'encoder.4.chain.2.op.1',
            #     'encoder.5.chain.2.op.1',
            #     'encoder.3.chain.2.op.1',
            #     'encoder.0.chain.1.op.0',
            #     'encoder.3.chain.0.op.1',
            #     'encoder.1.chain.1.op.1',
            #     'encoder.1.chain.1.op.0',
            #     'encoder.9.chain.2.op.1',
            #     'encoder.2.chain.1.op.1',
            #     'encoder.4.chain.0.op.1',
            #     'encoder.2.chain.1.op.0',
            #     'encoder.8.chain.2.op.1',
            #     'encoder.3.chain.1.op.1',
            #     'encoder.7.chain.2.op.1',
            #     'encoder.6.chain.2.op.1',
            #     'encoder.6.chain.0.op.1',
            #     'encoder.5.chain.0.op.1',
            #     'encoder.4.chain.1.op.1',
            #     'encoder.12.chain.1.op.1',
            #     'encoder.5.chain.1.op.1',
            #     'encoder.12.chain.2.op.1',
            #     'encoder.1.chain.0.op.0',
            #     'encoder.4.chain.1.op.0',
            #     'encoder.10.chain.2.op.1',
            #     'encoder.11.chain.2.op.1',
            #     'encoder.8.chain.0.op.1',
            #     'encoder.5.chain.1.op.0',
            #     'encoder.7.chain.0.op.1',
            #     'encoder.6.chain.1.op.1',
            #     'encoder.12.chain.1.op.0',
            #     'encoder.1.chain.2.op.0',
            #     'encoder.7.chain.1.op.0',
            #     'encoder.17.aspp_branch_2.1.op.0',
            #     'encoder.2.chain.2.op.0',
            #     'encoder.9.chain.0.op.1',
            #     'encoder.11.chain.1.op.1',
            #     'encoder.15.chain.2.op.1',
            #     'encoder.12.chain.0.op.1',
            #     'encoder.11.chain.0.op.1',
            #     'encoder.11.chain.1.op.0',
            #     'encoder.8.chain.1.op.0',
            #     'encoder.8.chain.1.op.1',
            #     'encoder.3.chain.0.op.0',
            #     'encoder.14.chain.1.op.0',
            #     'encoder.9.chain.1.op.0',
            #     'encoder.3.chain.2.op.0',
            #     'encoder.9.chain.1.op.1',
            #     'encoder.13.chain.2.op.1',
            #     'encoder.10.chain.0.op.1',
            #     'encoder.14.chain.2.op.1',
            #     'encoder.3.chain.1.op.0',
            #     'encoder.4.chain.0.op.0',
            #     'encoder.7.chain.1.op.1',
            #     'encoder.10.chain.1.op.1',
            #     'encoder.10.chain.1.op.0',
            #     'encoder.2.chain.0.op.0',
            #     'encoder.13.chain.0.op.1',
            #     'encoder.17.aspp_branch_2.1.op.1',
            #     'encoder.14.chain.1.op.1',
            #     'encoder.4.chain.2.op.0',
            #     'encoder.15.chain.1.op.0',
            #     'encoder.5.chain.2.op.0',
            #     'encoder.6.chain.1.op.0',
            #     'encoder.15.chain.1.op.1',
            #     'encoder.6.chain.0.op.0',
            #     'encoder.14.chain.0.op.1',
            #     'encoder.5.chain.0.op.0',
            #     'encoder.16.chain.2.op.1',
            #     'encoder.6.chain.2.op.0',
            #     'encoder.13.chain.1.op.1',
            #     'encoder.15.chain.0.op.1',
            #     'encoder.17.aspp_branch_1.op.1',
            #     'encoder.17.aspp_projection.op.1',
            #     'encoder.13.chain.1.op.0',
            #     'encoder.9.chain.2.op.0',
            #     'encoder.16.chain.1.op.0',
            #     'encoder.8.chain.2.op.0',
            #     'encoder.8.chain.0.op.0',
            #     'encoder.7.chain.2.op.0',
            #     'encoder.7.chain.0.op.0',
            #     'encoder.10.chain.0.op.0',
            #     'encoder.10.chain.2.op.0',
            #     'encoder.11.chain.2.op.0',
            #     'encoder.12.chain.2.op.0',
            #     'encoder.11.chain.0.op.0',
            #     'encoder.16.chain.0.op.1',
            #     'encoder.12.chain.0.op.0',
            #     'encoder.9.chain.0.op.0',
            #     'encoder.13.chain.0.op.0',
            #     'encoder.13.chain.2.op.0',
            # ]

            # lw_cos Adam cagrad 40 epoch -0.05
            # task_branches = [
            #     'encoder.0.chain.0.op.0',
            #     'encoder.0.chain.1.op.1',
            #     'stem.op.0',
            #     'stem.op.1',
            #     'encoder.2.chain.0.op.1',
            #     'encoder.0.chain.0.op.1',
            #     'encoder.2.chain.2.op.1',
            #     'encoder.1.chain.0.op.1',
            #     'encoder.1.chain.2.op.1',
            #     'encoder.5.chain.2.op.1',
            #     'encoder.4.chain.2.op.1',
            #     'encoder.3.chain.2.op.1',
            #     'encoder.0.chain.1.op.0',
            #     'encoder.3.chain.0.op.1',
            #     'encoder.1.chain.1.op.1',
            #     'encoder.1.chain.1.op.0',
            #     'encoder.9.chain.2.op.1',
            #     'encoder.2.chain.1.op.1',
            #     'encoder.4.chain.0.op.1',
            #     'encoder.2.chain.1.op.0',
            #     'encoder.8.chain.2.op.1',
            #     'encoder.7.chain.2.op.1',
            #     'encoder.3.chain.1.op.1',
            #     'encoder.6.chain.2.op.1',
            #     'encoder.6.chain.0.op.1',
            #     'encoder.5.chain.0.op.1',
            #     'encoder.12.chain.2.op.1',
            #     'encoder.12.chain.1.op.1',
            #     'encoder.4.chain.1.op.1',
            #     'encoder.17.aspp_branch_2.1.op.0',
            #     'encoder.5.chain.1.op.1',
            #     'encoder.1.chain.0.op.0',
            #     'encoder.10.chain.2.op.1',
            #     'encoder.4.chain.1.op.0',
            #     'encoder.11.chain.2.op.1',
            #     'encoder.15.chain.2.op.1',
            #     'encoder.8.chain.0.op.1',
            #     'encoder.5.chain.1.op.0',
            #     'encoder.7.chain.0.op.1',
            #     'encoder.6.chain.1.op.1',
            #     'encoder.12.chain.1.op.0',
            #     'encoder.1.chain.2.op.0',
            #     'encoder.7.chain.1.op.0',
            #     'encoder.9.chain.0.op.1',
            #     'encoder.13.chain.2.op.1',
            #     'encoder.12.chain.0.op.1',
            #     'encoder.2.chain.2.op.0',
            #     'encoder.17.aspp_branch_2.1.op.1',
            #     'encoder.11.chain.1.op.1',
            #     'encoder.14.chain.2.op.1',
            #     'encoder.11.chain.0.op.1',
            #     'encoder.11.chain.1.op.0',
            #     'encoder.14.chain.1.op.0',
            #     'encoder.8.chain.1.op.0',
            #     'encoder.8.chain.1.op.1',
            #     'encoder.10.chain.0.op.1',
            #     'encoder.9.chain.1.op.0',
            #     'encoder.9.chain.1.op.1',
            #     'encoder.3.chain.0.op.0',
            #     'encoder.3.chain.2.op.0',
            #     'encoder.16.chain.2.op.1',
            #     'encoder.3.chain.1.op.0',
            #     'encoder.17.aspp_projection.op.1',
            #     'encoder.13.chain.0.op.1',
            #     'encoder.7.chain.1.op.1',
            #     'encoder.4.chain.0.op.0',
            #     'encoder.10.chain.1.op.1',
            #     'encoder.10.chain.1.op.0',
            #     'encoder.2.chain.0.op.0',
            #     'encoder.17.aspp_branch_1.op.1',
            #     'encoder.14.chain.1.op.1',
            #     'encoder.15.chain.1.op.1',
            #     'encoder.15.chain.1.op.0',
            #     'encoder.14.chain.0.op.1',
            #     'encoder.15.chain.0.op.1',
            #     'encoder.13.chain.1.op.1',
            #     'encoder.6.chain.1.op.0',
            #     'encoder.4.chain.2.op.0',
            #     'encoder.5.chain.2.op.0',
            #     'encoder.6.chain.0.op.0',
            #     'encoder.5.chain.0.op.0',
            #     'encoder.6.chain.2.op.0',
            #     'encoder.16.chain.0.op.1',
            #     'encoder.13.chain.1.op.0',
            #     'encoder.16.chain.1.op.0',
            #     'encoder.9.chain.2.op.0',
            #     'encoder.8.chain.2.op.0',
            #     'encoder.16.chain.1.op.1',
            #     'encoder.10.chain.2.op.0',
            #     'encoder.12.chain.2.op.0',
            #     'encoder.7.chain.2.op.0',
            #     'encoder.8.chain.0.op.0',
            #     'encoder.10.chain.0.op.0',
            #     'encoder.11.chain.2.op.0',
            #     'encoder.7.chain.0.op.0',
            #     'encoder.9.chain.0.op.0',
            #     'encoder.11.chain.0.op.0',
            #     'encoder.12.chain.0.op.0',
            #     'encoder.13.chain.0.op.0'
            # ]

            # lw_cos Adam cagrad 40 epoch -0.00
            # task_branches = [
            #     'encoder.17.aspp_branch_2.1.op.0',
            #     'encoder.17.aspp_branch_2.1.op.1',
            #     'encoder.0.chain.0.op.0',
            #     'encoder.16.chain.2.op.1',
            #     'encoder.17.aspp_projection.op.1',
            #     'stem.op.0',
            #     'encoder.0.chain.1.op.1',
            #     'stem.op.1',
            #     'encoder.2.chain.0.op.1',
            #     'encoder.15.chain.2.op.1',
            #     'encoder.0.chain.0.op.1',
            #     'encoder.17.aspp_projection.op.0',
            #     'encoder.16.chain.0.op.1',
            #     'encoder.16.chain.2.op.0',
            #     'encoder.2.chain.2.op.1',
            #     'encoder.1.chain.2.op.1',
            #     'encoder.1.chain.0.op.1',
            #     'encoder.17.aspp_branch_1.op.1',
            #     'encoder.5.chain.2.op.1',
            #     'encoder.14.chain.2.op.1',
            #     'encoder.16.chain.1.op.1',
            #     'encoder.4.chain.2.op.1',
            #     'encoder.3.chain.2.op.1',
            #     'encoder.15.chain.1.op.1',
            #     'encoder.0.chain.1.op.0',
            #     'encoder.15.chain.0.op.1',
            #     'encoder.9.chain.2.op.1',
            #     'encoder.12.chain.2.op.1',
            #     'encoder.13.chain.2.op.1',
            #     'encoder.1.chain.1.op.1',
            #     'encoder.3.chain.0.op.1',
            #     'encoder.1.chain.1.op.0',
            #     'encoder.16.chain.0.op.0',
            #     'encoder.14.chain.1.op.0',
            #     'encoder.11.chain.2.op.1',
            #     'encoder.14.chain.0.op.1',
            #     'encoder.8.chain.2.op.1',
            #     'encoder.10.chain.2.op.1',
            #     'encoder.15.chain.1.op.0',
            #     'encoder.14.chain.1.op.1',
            #     'encoder.15.chain.2.op.0',
            #     'encoder.2.chain.1.op.1',
            #     'encoder.2.chain.1.op.0',
            #     'encoder.7.chain.2.op.1',
            #     'encoder.12.chain.1.op.1',
            #     'encoder.6.chain.2.op.1',
            #     'encoder.13.chain.0.op.1',
            #     'encoder.3.chain.1.op.1',
            #     'encoder.4.chain.0.op.1',
            #     'encoder.12.chain.0.op.1',
            #     'encoder.5.chain.0.op.1',
            #     'encoder.6.chain.0.op.1',
            #     'encoder.13.chain.1.op.1',
            #     'encoder.12.chain.1.op.0',
            #     'encoder.17.aspp_branch_1.op.0',
            #     'encoder.15.chain.0.op.0',
            #     'encoder.1.chain.0.op.0',
            #     'encoder.8.chain.0.op.1',
            #     'encoder.16.chain.1.op.0',
            #     'encoder.4.chain.1.op.1',
            #     'encoder.11.chain.1.op.0',
            #     'encoder.9.chain.0.op.1',
            #     'encoder.11.chain.0.op.1',
            #     'encoder.5.chain.1.op.1',
            #     'encoder.7.chain.1.op.0',
            #     'encoder.14.chain.2.op.0',
            #     'encoder.10.chain.0.op.1',
            #     'encoder.7.chain.0.op.1',
            #     'encoder.4.chain.1.op.0',
            #     'encoder.6.chain.1.op.1',
            #     'encoder.1.chain.2.op.0',
            #     'encoder.11.chain.1.op.1',
            #     'encoder.5.chain.1.op.0',
            #     'encoder.14.chain.0.op.0',
            #     'encoder.13.chain.2.op.0',
            #     'encoder.10.chain.1.op.0',
            #     'encoder.9.chain.1.op.0',
            #     'encoder.2.chain.2.op.0',
            #     'encoder.8.chain.1.op.1',
            #     'encoder.8.chain.1.op.0',
            #     'encoder.10.chain.1.op.1',
            #     'encoder.9.chain.1.op.1',
            #     'encoder.3.chain.0.op.0',
            #     'encoder.13.chain.0.op.0',
            #     'encoder.7.chain.1.op.1',
            #     'encoder.2.chain.0.op.0',
            #     'encoder.3.chain.2.op.0',
            #     'encoder.3.chain.1.op.0',
            #     'encoder.13.chain.1.op.0',
            #     'encoder.12.chain.2.op.0',
            #     'encoder.4.chain.0.op.0',
            #     'encoder.11.chain.2.op.0',
            #     'encoder.6.chain.0.op.0',
            #     'encoder.6.chain.1.op.0',
            #     'encoder.5.chain.0.op.0',
            #     'encoder.5.chain.2.op.0',
            #     'encoder.6.chain.2.op.0',
            #     'encoder.4.chain.2.op.0',
            #     'encoder.10.chain.2.op.0'
            # ]
            
            # lw_cos Adam cagrad 40 epoch -0.02
            task_branches = [
                'encoder.0.chain.0.op.0',
                'stem.op.0',
                'encoder.0.chain.1.op.1',
                'stem.op.1',
                'encoder.2.chain.0.op.1',
                'encoder.0.chain.0.op.1',
                'encoder.2.chain.2.op.1',
                'encoder.1.chain.2.op.1',
                'encoder.1.chain.0.op.1',
                'encoder.17.aspp_branch_2.1.op.0',
                'encoder.5.chain.2.op.1',
                'encoder.4.chain.2.op.1',
                'encoder.3.chain.2.op.1',
                'encoder.0.chain.1.op.0',
                'encoder.3.chain.0.op.1',
                'encoder.1.chain.1.op.1',
                'encoder.1.chain.1.op.0',
                'encoder.9.chain.2.op.1',
                'encoder.17.aspp_branch_2.1.op.1',
                'encoder.2.chain.1.op.1',
                'encoder.15.chain.2.op.1',
                'encoder.8.chain.2.op.1',
                'encoder.2.chain.1.op.0',
                'encoder.4.chain.0.op.1',
                'encoder.12.chain.2.op.1',
                'encoder.7.chain.2.op.1',
                'encoder.6.chain.2.op.1',
                'encoder.3.chain.1.op.1',
                'encoder.16.chain.2.op.1',
                'encoder.12.chain.1.op.1',
                'encoder.6.chain.0.op.1',
                'encoder.17.aspp_projection.op.1',
                'encoder.5.chain.0.op.1',
                'encoder.10.chain.2.op.1',
                'encoder.11.chain.2.op.1',
                'encoder.4.chain.1.op.1',
                'encoder.1.chain.0.op.0',
                'encoder.5.chain.1.op.1',
                'encoder.14.chain.2.op.1',
                'encoder.13.chain.2.op.1',
                'encoder.4.chain.1.op.0',
                'encoder.8.chain.0.op.1',
                'encoder.12.chain.1.op.0',
                'encoder.14.chain.1.op.0',
                'encoder.12.chain.0.op.1',
                'encoder.7.chain.0.op.1',
                'encoder.6.chain.1.op.1',
                'encoder.5.chain.1.op.0',
                'encoder.7.chain.1.op.0',
                'encoder.1.chain.2.op.0',
                'encoder.9.chain.0.op.1',
                'encoder.17.aspp_branch_1.op.1',
                'encoder.11.chain.0.op.1',
                'encoder.11.chain.1.op.0',
                'encoder.11.chain.1.op.1',
                'encoder.2.chain.2.op.0',
                'encoder.10.chain.0.op.1',
                'encoder.13.chain.0.op.1',
                'encoder.15.chain.1.op.1',
                'encoder.15.chain.0.op.1',
                'encoder.16.chain.0.op.1',
                'encoder.8.chain.1.op.1',
                'encoder.8.chain.1.op.0',
                'encoder.14.chain.1.op.1',
                'encoder.9.chain.1.op.0',
                'encoder.9.chain.1.op.1',
                'encoder.15.chain.1.op.0',
                'encoder.14.chain.0.op.1',
                'encoder.3.chain.0.op.0',
                'encoder.3.chain.2.op.0',
                'encoder.3.chain.1.op.0',
                'encoder.10.chain.1.op.0',
                'encoder.7.chain.1.op.1',
                'encoder.10.chain.1.op.1',
                'encoder.13.chain.1.op.1',
                'encoder.2.chain.0.op.0',
                'encoder.4.chain.0.op.0',
                'encoder.16.chain.1.op.0',
                'encoder.16.chain.1.op.1',
                'encoder.6.chain.1.op.0',
                'encoder.4.chain.2.op.0',
                'encoder.13.chain.1.op.0',
                'encoder.5.chain.2.op.0',
                'encoder.6.chain.0.op.0',
                'encoder.5.chain.0.op.0',
                'encoder.6.chain.2.op.0',
                'encoder.16.chain.2.op.0',
                'encoder.9.chain.2.op.0',
                'encoder.12.chain.2.op.0',
                'encoder.10.chain.2.op.0',
                'encoder.15.chain.2.op.0',
                'encoder.11.chain.2.op.0',
                'encoder.8.chain.2.op.0',
                'encoder.10.chain.0.op.0',
                'encoder.13.chain.2.op.0',
                'encoder.8.chain.0.op.0',
                'encoder.7.chain.2.op.0',
                'encoder.12.chain.0.op.0',
                'encoder.7.chain.0.op.0'
            ]
            
            # task_branches = [b for b in task_branches if not is_father_str(b, ['stem.op.', 'encoder.0.', 'encoder.1.', 'encoder.2.'])]
            #
            task_branches = task_branches[:topK]

        self._obtian_td_layers()
        self.turn(task_branches)

        self._initialize_weights()

        if pretrain:
            self._load_imagenet_weights()

        # self.turn()
        # print('---')

    def turn(self, task_branches=None):
        """Make the given modules task-specific and rebuild both parameter registries.

        Args:
            task_branches: Iterable of module names (as yielded by
                ``named_modules()``) that should become task-specific.
                ``None`` or an empty iterable selects nothing.

        When the model has more than one task, ``set_n_tasks`` is called on
        every named module listed in ``task_branches``. Afterwards the direct
        parameters of every module are filed under
        ``_task_specific_parameters`` when the module name (cut off at the
        first ``.m_list``) is in ``task_branches`` or ``self.td_layers``, and
        under ``_shared_parameters`` otherwise.
        """
        branches = [] if task_branches is None else list(task_branches)

        # Rebuild both registries from scratch so that a repeated call does
        # not leave entries from a previous configuration behind.
        self._shared_parameters.clear()
        self._task_specific_parameters.clear()

        if self.n_tasks > 1:
            requested = set(branches)
            for name, module in self.named_modules():
                if name in requested:
                    module.set_n_tasks(n_tasks=self.n_tasks)

        task_specific_names = set(branches).union(self.td_layers)

        # obtain the shared / task-specific parameters
        for name, module in self.named_modules():
            if '.m_list' in name:
                # Sub-modules below an ``m_list`` are classified (and keyed)
                # by the name of the module that owns the list.
                # NOTE(review): several sub-modules under one ``m_list`` map to
                # the same key, so the last one wins - confirm this is intended.
                name = name[:name.index('.m_list')]

            if name in task_specific_names:
                registry = self._task_specific_parameters
            else:
                registry = self._shared_parameters

            prefix = name + ('.' if name else '')
            seen = set()
            for key, param in module._parameters.items():
                if param is None or param in seen:
                    continue
                seen.add(param)
                registry[prefix + key] = param

    def shared_parameters(self):
        """Return the name -> parameter registry of shared parameters.

        The registry object itself is returned, not a copy.
        """
        registry = self._shared_parameters
        return registry

    def model_size(self):
        """Return the total size of all parameters in MiB (bytes / 1024**2)."""
        total_bytes = sum(p.nelement() * p.element_size()
                          for p in self.parameters())
        return total_bytes / 1024 ** 2

    def forward(self, x):
        """Run stem and encoder, then decode one prediction per task.

        Args:
            x: Input batch; its last two dimensions are passed to every
                decoder as the third argument.

        Returns:
            dict mapping each task name to the output of its decoder.

        Raises:
            ValueError: if the number of encoder outputs and the number of
                stored intermediate outputs is not a supported combination.
        """
        input_shape = x.shape[-2:]
        out = self.stem(x)

        for layer_no, layer in enumerate(self.encoder):
            out = layer(out)
            if layer_no == 2:
                # keep a copy of the features after the third encoder block;
                # it is the second input of every decoder
                out_low = modules_fw.clone_fw(out)

        n_tasks = len(self.tasks)
        n_out, n_low = len(out), len(out_low)

        # For each task: (index into `out`, index into `out_low`).
        if n_out == 1 and n_low == 1:
            # nothing is split per task
            picks = [(0, 0)] * n_tasks
        elif n_out == n_tasks and n_low == n_tasks:
            # both feature sets hold one entry per task
            picks = [(i, i) for i in range(n_tasks)]
        elif n_out == n_tasks and n_low == 1:
            # only the final features hold one entry per task
            picks = [(i, 0) for i in range(n_tasks)]
        else:
            raise ValueError('Error')

        output = {}
        for task, (hi, lo) in zip(self.tasks, picks):
            output[task] = self.decoder[task](out[hi], out_low[lo], input_shape)

        return output

    def zero_grad_shared_modules(self):
        """Zero, in place, the gradient of every shared parameter that has one."""
        for param in self.shared_parameters().values():
            grad = param.grad
            if grad is None:
                continue
            # Cut the gradient loose from any graph before zeroing it.
            if grad.grad_fn is None:
                grad.requires_grad_(False)
            else:
                grad.detach_()
            grad.zero_()

    def _initialize_weights(self):
        """Initialise all Conv2d and affine BatchNorm2d layers of the model.

        Conv2d weights are drawn from N(0, 0.01**2) when the module name
        contains ``'logits'`` and from Kaiming-normal (``fan_out``) otherwise;
        Conv2d biases are zeroed. Affine BatchNorm2d layers get weight 1 and
        bias 0.
        """
        for name, module in self.named_modules():
            if isinstance(module, nn.Conv2d):
                if 'logits' not in name:
                    nn.init.kaiming_normal_(module.weight, mode='fan_out')
                else:
                    # initialize final prediction layer with fixed std
                    nn.init.normal_(module.weight, mean=0, std=0.01)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue

            if isinstance(module, nn.BatchNorm2d) and module.affine:
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _obtian_td_layers(self):
        """Collect the names of modules that hold task-dependent parameters.

        A module is recorded in ``self.td_layers`` when the qualified name of
        at least one of its direct parameters contains ``'decoder'``. The
        resulting list has no duplicates and no particular order.
        """
        markers = ['decoder']

        self.td_layers = []
        for mod_name, module in self.named_modules():
            prefix = mod_name + ('.' if mod_name else '')
            seen = set()
            for key, param in module._parameters.items():
                if param is not None and param not in seen:
                    seen.add(param)
                    if is_father_str(prefix + key, markers):
                        self.td_layers.append(mod_name)

        # a module is appended once per matching parameter; drop the repeats
        self.td_layers = list(set(self.td_layers))

    def _load_imagenet_weights(self):
        """Initialise stem and encoder from the MobileNetV2 checkpoint at ``MODEL_URL``.

        Every key of this model's state dict is translated to the matching key
        of the downloaded checkpoint. Keys starting with ``decoder`` or
        ``encoder.17`` are skipped. The translated tensors are loaded with
        ``strict=False``.

        Raises:
            ValueError: if a key starts with neither ``stem``, ``encoder`` nor
                ``decoder``, or if it contains no ``m_list`` component.
            KeyError: if a translated key is not present in the checkpoint.
        """
        # we are using pretrained weights from torchvision.models:
        pretrained = torch.hub.load_state_dict_from_url(
            MODEL_URL, progress=False)

        remapped = {}
        for name_trg in self.state_dict():
            if name_trg.startswith('decoder') or name_trg.startswith('encoder.17'):
                continue  # can't load decoder and ASPP

            from_stem = name_trg.startswith('stem')
            if not from_stem and not name_trg.startswith('encoder'):
                raise ValueError

            # drop the 'm_list.<k>' pair, which has no counterpart in the checkpoint
            tokens = name_trg.split('.')
            cut = tokens.index('m_list')
            tokens[cut:cut + 2] = []

            if from_stem:
                name_src = '.'.join(tokens).replace('stem.op', 'features.0')
            else:
                # this is highly specific to the current naming:
                # encoder.<layer>.chain.<chain>.op.<op>.<rest>
                layer_nr = int(tokens[1])
                chain_nr = int(tokens[3])
                op_nr = int(tokens[5])
                if chain_nr == 0 or (chain_nr == 1 and layer_nr != 0):
                    # -> features.<layer+1>.conv.<chain>.<op>.<rest>
                    middle = [str(chain_nr), tokens[5]]
                else:
                    # -> features.<layer+1>.conv.<chain+op>.<rest>
                    middle = [str(chain_nr + op_nr)]
                name_src = '.'.join(
                    ['features', str(layer_nr + 1), 'conv'] + middle + tokens[6:])

            remapped[name_trg] = pretrained[name_src]

        self.load_state_dict(remapped, strict=False)

class BranchMobileNetV2_fw_v2(nn.Module):
    """
    Branched multi-task network based on a MobileNetV2 backbone, R-ASPP module and
    DeepLabv3+ head.
    branch_config: Array (or nested list) of shape (n_layers X n_tasks) expressing which
    blocks are sampled for each task at each layer. Example (n_layers=6, n_tasks=4):
    branch_config = [
        [0, 0, 0, 0],
        [2, 2, 2, 2],
        [1, 1, 2, 2],
        [0, 0, 2, 3],
        [0, 1, 2, 3],
        [1, 1, 3, 3]
    ]
    This array determines the branching configuration.
    """

    def __init__(self, tasks, branch_config, branched='empty', topK=None, pretrain=True):

        super(BranchMobileNetV2_fw_v2, self).__init__()
        self.tasks = tasks
        self.n_tasks = len(tasks)
        self.default_mapping = [frozenset([i for i in range(self.n_tasks)])]

        self.branch_config = branch_config

        self._shared_parameters = OrderedDict()
        self._task_specific_parameters = OrderedDict()

        self.stem = modules_fw_v2.ConvBNReLU_fw_v2(in_channels=3,
                                             out_channels=32,
                                             kernel_size=3,
                                             stride=2,
                                             padding=1,
                                             activation='relu6',
                                             mapping=self.default_mapping)

        mappings = self._get_branch_mappings()

        # blocks_type = []
        # self.shared_flag = -1
        #
        # for i, map_list in enumerate(mappings):
        #     if len(map_list) == 1:
        #         blocks_type.append(modules_fw.InvertedResidual_fw)
        #         if self.shared_flag >= 0:
        #             raise ValueError('Error: the structure is wrong.')
        #     else:
        #         blocks_type.append(modules.InvertedResidual)
        #         if self.shared_flag < 0:
        #             self.shared_flag = i

        blocks = [
            modules_fw_v2.InvertedResidual_fw_v2(
                32, 16, 1, 1, dilation=1, activation='relu6', mapping=mappings[0]),
            modules_fw_v2.InvertedResidual_fw_v2(
                16, 24, 2, 6, dilation=1, activation='relu6', mapping=mappings[1]),
            modules_fw_v2.InvertedResidual_fw_v2(
                24, 24, 1, 6, dilation=1, activation='relu6', mapping=mappings[2]),
            modules_fw_v2.InvertedResidual_fw_v2(
                24, 32, 2, 6, dilation=1, activation='relu6', mapping=mappings[3]),
            modules_fw_v2.InvertedResidual_fw_v2(
                32, 32, 1, 6, dilation=1, activation='relu6', mapping=mappings[4]),
            modules_fw_v2.InvertedResidual_fw_v2(
                32, 32, 1, 6, dilation=1, activation='relu6', mapping=mappings[5]),
            modules_fw_v2.InvertedResidual_fw_v2(
                32, 64, 2, 6, dilation=1, activation='relu6', mapping=mappings[6]),
            modules_fw_v2.InvertedResidual_fw_v2(
                64, 64, 1, 6, dilation=1, activation='relu6', mapping=mappings[7]),
            modules_fw_v2.InvertedResidual_fw_v2(
                64, 64, 1, 6, dilation=1, activation='relu6', mapping=mappings[8]),
            modules_fw_v2.InvertedResidual_fw_v2(
                64, 64, 1, 6, dilation=1, activation='relu6', mapping=mappings[9]),
            modules_fw_v2.InvertedResidual_fw_v2(
                64, 96, 1, 6, dilation=1, activation='relu6', mapping=mappings[10]),
            modules_fw_v2.InvertedResidual_fw_v2(
                96, 96, 1, 6, dilation=1, activation='relu6', mapping=mappings[11]),
            modules_fw_v2.InvertedResidual_fw_v2(
                96, 96, 1, 6, dilation=1, activation='relu6', mapping=mappings[12]),
            modules_fw_v2.InvertedResidual_fw_v2(
                96, 160, 2, 6, dilation=2, activation='relu6', mapping=mappings[13]),
            modules_fw_v2.InvertedResidual_fw_v2(
                160, 160, 1, 6, dilation=2, activation='relu6', mapping=mappings[14]),
            modules_fw_v2.InvertedResidual_fw_v2(
                160, 160, 1, 6, dilation=2, activation='relu6', mapping=mappings[15]),
            modules_fw_v2.InvertedResidual_fw_v2(
                160, 320, 1, 6, dilation=2, activation='relu6', mapping=mappings[16]),
            modules_fw_v2.RASPP_fw_v2(320, 128, activation='relu6', drop_rate=0.1, mapping=mappings[17])
        ]

        self.encoder = nn.Sequential(*blocks)

        self.decoder = nn.ModuleDict({
            task: modules.DeepLabV3PlusDecoder(
                in_channels_low=24,
                out_f_classifier=128,
                use_separable=True,
                activation='relu6',
                num_outputs=TASK_CHANNEL_MAPPING[task])
            for task in self.tasks})

        self.x_decoder_mapping = {task: i for i, task in enumerate(self.tasks)}

        if branched == 'empty':
            task_branches = []
        elif branched == 'branch':
            # lw_cos Adam -0.1
            # task_branches = ['shared.0.chain.0.op.0', 'shared.0.chain.1.op.1', 'stem.op.1', 'stem.op.0',
            #                  'shared.2.chain.0.op.1',
            #                  'shared.0.chain.0.op.1', 'shared.2.chain.2.op.1', 'shared.1.chain.0.op.1',
            #                  'shared.1.chain.2.op.1', 'shared.5.chain.2.op.1',
            #                  'shared.3.chain.0.op.1', 'shared.4.chain.2.op.1', 'shared.3.chain.2.op.1',
            #                  'shared.0.chain.1.op.0', 'shared.1.chain.1.op.1',
            #                  'shared.1.chain.1.op.0', 'shared.2.chain.1.op.1', 'shared.4.chain.0.op.1',
            #                  'shared.2.chain.1.op.0', 'shared.3.chain.1.op.1',
            #                  'shared.5.chain.0.op.1', 'shared.4.chain.1.op.1', 'shared.4.chain.1.op.0',
            #                  'shared.5.chain.1.op.1', 'shared.1.chain.0.op.0']
            # # lw_cos Adam 0.0
#             task_branches = ['encoder.0.chain.0.op.0', 'stem.op.1',
# 'stem.op.0',
# 'encoder.0.chain.1.op.1',
# 'encoder.2.chain.0.op.1',
# 'encoder.0.chain.0.op.1',
# 'encoder.2.chain.2.op.1',
# 'encoder.1.chain.2.op.1',
# 'encoder.1.chain.0.op.1',
# 'encoder.5.chain.2.op.1',
# 'encoder.4.chain.2.op.1',
# 'encoder.3.chain.2.op.1',
# 'encoder.3.chain.0.op.1',
# 'encoder.0.chain.1.op.0',
# 'encoder.1.chain.1.op.0',
# 'encoder.1.chain.1.op.1',
# 'encoder.9.chain.2.op.1',
# 'encoder.2.chain.1.op.0',
# 'encoder.2.chain.1.op.1',
# 'encoder.4.chain.0.op.1',
# 'encoder.8.chain.2.op.1',
# 'encoder.7.chain.2.op.1',
# 'encoder.6.chain.2.op.1',
# 'encoder.3.chain.1.op.1',
# 'encoder.5.chain.0.op.1',
# 'encoder.6.chain.0.op.1',
# 'encoder.12.chain.2.op.1',
# 'encoder.12.chain.1.op.1',
# 'encoder.4.chain.1.op.1',
# 'encoder.4.chain.1.op.0',
# 'encoder.5.chain.1.op.1',
# 'encoder.8.chain.0.op.1',
# 'encoder.10.chain.2.op.1',
# 'encoder.11.chain.2.op.1',
# 'encoder.1.chain.0.op.0',
# 'encoder.17.aspp_branch_2.1.op.0',
# 'encoder.5.chain.1.op.0',
# 'encoder.6.chain.1.op.1',
# 'encoder.7.chain.0.op.1',
# 'encoder.12.chain.1.op.0',
# 'encoder.7.chain.1.op.0',
# 'encoder.12.chain.0.op.1',
# 'encoder.9.chain.0.op.1',
# 'encoder.15.chain.2.op.1',
# 'encoder.11.chain.0.op.1',
# 'encoder.11.chain.1.op.1',
# 'encoder.13.chain.2.op.1',
# 'encoder.1.chain.2.op.0',
# 'encoder.8.chain.1.op.0',
# 'encoder.14.chain.2.op.1',
# 'encoder.14.chain.1.op.0',
# 'encoder.8.chain.1.op.1',
# 'encoder.3.chain.0.op.0',
# 'encoder.10.chain.0.op.1',
# 'encoder.11.chain.1.op.0',
# 'encoder.3.chain.2.op.0',
# 'encoder.9.chain.1.op.1',
# 'encoder.2.chain.2.op.0',
# 'encoder.9.chain.1.op.0',
# 'encoder.13.chain.0.op.1',
# 'encoder.17.aspp_branch_2.1.op.1',
# 'encoder.7.chain.1.op.1',
# 'encoder.3.chain.1.op.0',
# 'encoder.4.chain.0.op.0',
# 'encoder.10.chain.1.op.1',
# 'encoder.10.chain.1.op.0',
# 'encoder.17.aspp_projection.op.1',
# 'encoder.14.chain.1.op.1',
# 'encoder.15.chain.1.op.1',
# 'encoder.14.chain.0.op.1',
# 'encoder.15.chain.1.op.0',
# 'encoder.16.chain.2.op.1',
# 'encoder.4.chain.2.op.0',
# 'encoder.6.chain.1.op.0',
# 'encoder.5.chain.2.op.0',
# 'encoder.6.chain.0.op.0',
# 'encoder.6.chain.2.op.0',
# 'encoder.17.aspp_branch_1.op.1',
# 'encoder.13.chain.1.op.1',
# 'encoder.2.chain.0.op.0',
# 'encoder.5.chain.0.op.0',
# 'encoder.15.chain.0.op.1',
# 'encoder.16.chain.0.op.1',
# 'encoder.13.chain.1.op.0',
# 'encoder.9.chain.2.op.0',
# 'encoder.8.chain.2.op.0',
# 'encoder.10.chain.2.op.0',
# 'encoder.12.chain.2.op.0',
# 'encoder.11.chain.2.op.0',
# 'encoder.10.chain.0.op.0',
# 'encoder.7.chain.2.op.0',
# 'encoder.8.chain.0.op.0',
# 'encoder.16.chain.1.op.0',
# 'encoder.7.chain.0.op.0',
# 'encoder.12.chain.0.op.0',
# 'encoder.11.chain.0.op.0',
# 'encoder.9.chain.0.op.0',
# 'encoder.13.chain.0.op.0',
# 'encoder.16.chain.1.op.1',
# ]

            # lw_cos Adam cagrad 40 epoch -0.02
            task_branches = [
                'encoder.0.chain.0.op.0',
                'stem.op.0',
                'encoder.0.chain.1.op.1',
                'stem.op.1',
                'encoder.2.chain.0.op.1',
                'encoder.0.chain.0.op.1',
                'encoder.2.chain.2.op.1',
                'encoder.1.chain.2.op.1',
                'encoder.1.chain.0.op.1',
                'encoder.17.aspp_branch_2.1.op.0',
                'encoder.5.chain.2.op.1',
                'encoder.4.chain.2.op.1',
                'encoder.3.chain.2.op.1',
                'encoder.0.chain.1.op.0',
                'encoder.3.chain.0.op.1',
                'encoder.1.chain.1.op.1',
                'encoder.1.chain.1.op.0',
                'encoder.9.chain.2.op.1',
                'encoder.17.aspp_branch_2.1.op.1',
                'encoder.2.chain.1.op.1',
                'encoder.15.chain.2.op.1',
                'encoder.8.chain.2.op.1',
                'encoder.2.chain.1.op.0',
                'encoder.4.chain.0.op.1',
                'encoder.12.chain.2.op.1',
                'encoder.7.chain.2.op.1',
                'encoder.6.chain.2.op.1',
                'encoder.3.chain.1.op.1',
                'encoder.16.chain.2.op.1',
                'encoder.12.chain.1.op.1',
                'encoder.6.chain.0.op.1',
                'encoder.17.aspp_projection.op.1',
                'encoder.5.chain.0.op.1',
                'encoder.10.chain.2.op.1',
                'encoder.11.chain.2.op.1',
                'encoder.4.chain.1.op.1',
                'encoder.1.chain.0.op.0',
                'encoder.5.chain.1.op.1',
                'encoder.14.chain.2.op.1',
                'encoder.13.chain.2.op.1',
                'encoder.4.chain.1.op.0',
                'encoder.8.chain.0.op.1',
                'encoder.12.chain.1.op.0',
                'encoder.14.chain.1.op.0',
                'encoder.12.chain.0.op.1',
                'encoder.7.chain.0.op.1',
                'encoder.6.chain.1.op.1',
                'encoder.5.chain.1.op.0',
                'encoder.7.chain.1.op.0',
                'encoder.1.chain.2.op.0',
                'encoder.9.chain.0.op.1',
                'encoder.17.aspp_branch_1.op.1',
                'encoder.11.chain.0.op.1',
                'encoder.11.chain.1.op.0',
                'encoder.11.chain.1.op.1',
                'encoder.2.chain.2.op.0',
                'encoder.10.chain.0.op.1',
                'encoder.13.chain.0.op.1',
                'encoder.15.chain.1.op.1',
                'encoder.15.chain.0.op.1',
                'encoder.16.chain.0.op.1',
                'encoder.8.chain.1.op.1',
                'encoder.8.chain.1.op.0',
                'encoder.14.chain.1.op.1',
                'encoder.9.chain.1.op.0',
                'encoder.9.chain.1.op.1',
                'encoder.15.chain.1.op.0',
                'encoder.14.chain.0.op.1',
                'encoder.3.chain.0.op.0',
                'encoder.3.chain.2.op.0',
                'encoder.3.chain.1.op.0',
                'encoder.10.chain.1.op.0',
                'encoder.7.chain.1.op.1',
                'encoder.10.chain.1.op.1',
                'encoder.13.chain.1.op.1',
                'encoder.2.chain.0.op.0',
                'encoder.4.chain.0.op.0',
                'encoder.16.chain.1.op.0',
                'encoder.16.chain.1.op.1',
                'encoder.6.chain.1.op.0',
                'encoder.4.chain.2.op.0',
                'encoder.13.chain.1.op.0',
                'encoder.5.chain.2.op.0',
                'encoder.6.chain.0.op.0',
                'encoder.5.chain.0.op.0',
                'encoder.6.chain.2.op.0',
                'encoder.16.chain.2.op.0',
                'encoder.9.chain.2.op.0',
                'encoder.12.chain.2.op.0',
                'encoder.10.chain.2.op.0',
                'encoder.15.chain.2.op.0',
                'encoder.11.chain.2.op.0',
                'encoder.8.chain.2.op.0',
                'encoder.10.chain.0.op.0',
                'encoder.13.chain.2.op.0',
                'encoder.8.chain.0.op.0',
                'encoder.7.chain.2.op.0',
                'encoder.12.chain.0.op.0',
                'encoder.7.chain.0.op.0'
            ]

            task_branches = task_branches[:topK]

        self._obtian_td_layers()
        self.turn(task_branches)

        self._initialize_weights()

        if pretrain:
            self._load_imagenet_weights()

    def turn(self, task_branches=None):
        """Make the given modules task-specific and rebuild both parameter registries.

        Args:
            task_branches: Iterable of module names (as yielded by
                ``named_modules()``) that should become task-specific.
                ``None`` or an empty iterable selects nothing.

        When the model has more than one task, ``set_mapping`` is called on
        every named module listed in ``task_branches`` with one singleton
        ``frozenset`` per task index. Afterwards the direct parameters of every
        module are filed under ``_task_specific_parameters`` when the module
        name (cut off at the first ``.m_list``) is in ``task_branches`` or
        ``self.td_layers``, and under ``_shared_parameters`` otherwise.
        """
        branches = [] if task_branches is None else list(task_branches)

        # Rebuild both registries from scratch so that a repeated call does
        # not leave entries from a previous configuration behind.
        self._shared_parameters.clear()
        self._task_specific_parameters.clear()

        if self.n_tasks > 1:
            requested = set(branches)
            for name, module in self.named_modules():
                if name in requested:
                    # a fresh list per module, so modules never share one object
                    module.set_mapping(
                        [frozenset([i]) for i in range(self.n_tasks)])

        task_specific_names = set(branches).union(self.td_layers)

        # obtain the shared / task-specific parameters
        for name, module in self.named_modules():
            if '.m_list' in name:
                # Sub-modules below an ``m_list`` are classified (and keyed)
                # by the name of the module that owns the list.
                # NOTE(review): several sub-modules under one ``m_list`` map to
                # the same key, so the last one wins - confirm this is intended.
                name = name[:name.index('.m_list')]

            if name in task_specific_names:
                registry = self._task_specific_parameters
            else:
                registry = self._shared_parameters

            prefix = name + ('.' if name else '')
            seen = set()
            for key, param in module._parameters.items():
                if param is None or param in seen:
                    continue
                seen.add(param)
                registry[prefix + key] = param

    def shared_parameters(self):
        """Return the name -> parameter registry of shared parameters.

        The registry object itself is returned, not a copy.
        """
        registry = self._shared_parameters
        return registry

    def zero_grad_shared_modules(self):
        """Zero, in place, the gradient of every shared parameter that has one."""
        for param in self.shared_parameters().values():
            grad = param.grad
            if grad is None:
                continue
            # Cut the gradient loose from any graph before zeroing it.
            if grad.grad_fn is None:
                grad.requires_grad_(False)
            else:
                grad.detach_()
            grad.zero_()

    def _obtian_td_layers(self):
        """Collect the names of modules that hold task-dependent parameters.

        A module is recorded in ``self.td_layers`` when the qualified name of
        at least one of its direct parameters contains ``'decoder'`` or
        ``'encoder'``. The resulting list has no duplicates and no particular
        order.
        """
        markers = ['decoder', 'encoder']

        self.td_layers = []
        for mod_name, module in self.named_modules():
            prefix = mod_name + ('.' if mod_name else '')
            seen = set()
            for key, param in module._parameters.items():
                if param is not None and param not in seen:
                    seen.add(param)
                    if is_father_str(prefix + key, markers):
                        self.td_layers.append(mod_name)

        # a module is appended once per matching parameter; drop the repeats
        self.td_layers = list(set(self.td_layers))

    def model_size(self):
        """Return the total size of all parameters in MiB (bytes / 1024**2)."""
        total_bytes = sum(p.nelement() * p.element_size()
                          for p in self.parameters())
        return total_bytes / 1024 ** 2

    def _get_branch_mappings(self):
        """
        Calculates branch mappings from the branch_config. `mappings` is a list of dicts mapping
        the index of an input branch to indices of output branches. For example:
        mappings = [
            {0: [0]},
            {0: [0, 1]},
            {0: [0, 2], 1: [1, 3]},
            {0: [1], 1: [2], 2: [3], 3: [0]}
        ]
        """

        def get_partition(layer_config):
            partition = []
            for t in range(len(self.tasks)):
                s = {i for i, x in enumerate(layer_config) if x == t}
                if not len(s) == 0:
                    partition.append(s)
            return partition

        def make_refinement(partition, ancestor):
            """ make `partition` a refinement of `ancestor` """
            refinement = []
            for part_1 in partition:
                for part_2 in ancestor:
                    inter = part_1.intersection(part_2)
                    if not len(inter) == 0:
                        refinement.append(inter)
            return refinement

        task_grouping = [set(range(len(self.tasks))), ]

        refinement_all = []
        for layer_idx, layer_config in enumerate(self.branch_config):
            partition = get_partition(layer_config)
            refinement = make_refinement(partition, task_grouping)
            refinement_all.append(refinement)

        return refinement_all

    def _get_output_key_with_task_id(self, keys_list, id):
        out = []
        for keys in keys_list:
            if id in keys:
                out.append(keys)
        assert len(out) == 1
        return out[0]

    def forward(self, x):
        """Run stem and encoder, then decode one prediction per task.

        Features travel through the network as dicts keyed by frozensets of
        task indices. For every task the entry whose key contains the task's
        index is handed to that task's decoder, together with the matching
        entry of the features stored after the third encoder block.

        Args:
            x: Input batch; its last two dimensions are passed to every
                decoder as the third argument.

        Returns:
            dict mapping each task name to the output of its decoder.

        Raises:
            AssertionError: if a feature dict has no exact per-task key and
                not exactly one key containing the task index.
        """
        input_shape = x.shape[-2:]
        feats = {self.default_mapping[0]: x}
        feats = self.stem(feats)

        for layer_no, layer in enumerate(self.encoder):
            feats = layer(feats)
            if layer_no == 2:
                # keep a copy of the features after the third encoder block;
                # it is the second input of every decoder
                low_feats = modules_fw_v2.clone_fw_v2(feats)

        def select(branches, task_id):
            # Prefer the task's own entry. Otherwise fall back to the group
            # containing the task. The two feature dicts are resolved
            # independently because they can be split to different degrees
            # (e.g. the final features per task while the stored ones are still
            # shared), which used to raise KeyError.
            own_key = frozenset([task_id])
            if own_key in branches:
                return branches[own_key]
            group_key = self._get_output_key_with_task_id(branches.keys(), task_id)
            return branches[group_key]

        output = {}
        for task in self.tasks:
            task_id = self.x_decoder_mapping[task]
            output[task] = self.decoder[task](select(feats, task_id),
                                              select(low_feats, task_id),
                                              input_shape)

        return output

    def _initialize_weights(self):
        # weight initialization
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'logits' in name:
                    # initialize final prediction layer with fixed std
                    nn.init.normal_(m.weight, mean=0, std=0.01)
                else:
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                if m.affine:
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)

    def _load_imagenet_weights(self):
        """Initialise stem and encoder from the MobileNetV2 checkpoint at ``MODEL_URL``.

        Every key of this model's state dict is translated to the matching key
        of the downloaded checkpoint. Keys starting with ``decoder`` or
        ``encoder.17`` are skipped. The translated tensors are loaded with
        ``strict=False``.

        Raises:
            ValueError: if a key starts with neither ``stem``, ``encoder`` nor
                ``decoder``, or if it contains no ``m_list`` component.
            KeyError: if a translated key is not present in the checkpoint.
        """
        # we are using pretrained weights from torchvision.models:
        pretrained = torch.hub.load_state_dict_from_url(
            MODEL_URL, progress=False)

        remapped = {}
        for name_trg in self.state_dict():
            if name_trg.startswith('decoder') or name_trg.startswith(f'encoder.17'):
                continue  # can't load decoder and ASPP

            from_stem = name_trg.startswith('stem')
            if not from_stem and not name_trg.startswith('encoder'):
                raise ValueError

            # drop the 'm_list.<k>' pair, which has no counterpart in the checkpoint
            tokens = name_trg.split('.')
            cut = tokens.index('m_list')
            tokens[cut:cut + 2] = []

            if from_stem:
                name_src = '.'.join(tokens).replace('stem.op', 'features.0')
            else:
                # expected layout: encoder.<layer>.chain.<chain>.op.<op>.<rest>
                layer_nr = int(tokens[1])
                chain_nr = int(tokens[3])
                op_nr = int(tokens[5])
                if chain_nr == 0 or (chain_nr == 1 and layer_nr != 0):
                    # -> features.<layer+1>.conv.<chain>.<op>.<rest>
                    middle = [str(chain_nr), tokens[5]]
                else:
                    # -> features.<layer+1>.conv.<chain+op>.<rest>
                    middle = [str(chain_nr + op_nr)]
                name_src = '.'.join(
                    ['features', str(layer_nr + 1), 'conv'] + middle + tokens[6:])

            remapped[name_trg] = pretrained[name_src]

        self.load_state_dict(remapped, strict=False)

class BranchMobileNetV2_fw(nn.Module):
    """
    Branched multi-task network based on a MobileNetV2 backbone, R-ASPP module and
    DeepLabv3+ head.
    branch_config: Array (or nested list) of shape (n_layers X n_tasks) expressing which
    blocks are sampled for each task at each layer. Example (n_layers=6, n_tasks=4):
    branch_config = [
        [0, 0, 0, 0],
        [2, 2, 2, 2],
        [1, 1, 2, 2],
        [0, 0, 2, 3],
        [0, 1, 2, 3],
        [1, 1, 3, 3]
    ]
    This array determines the branching configuration.
    """

    def __init__(self, tasks, branch_config, branched='empty', topK=None, pretrain=True):
        """
        Build the stem, the shared trunk, the branched encoder and one decoder per task.

        tasks: sequence of task names; each must be a key of TASK_CHANNEL_MAPPING.
        branch_config: per-layer block assignment of every task (see class docstring).
        branched: 'empty' requests no extra task-specific layers; 'branch' turns the
            first `topK` entries of the hard-coded layer ranking below task-specific.
        topK: length of the ranking prefix used with branched='branch'
            (None keeps the whole ranking).
        pretrain: when true, ImageNet weights are loaded after initialization.

        Raises ValueError when a single-group layer follows a branched layer, or
        when `branched` is not one of the supported modes.
        """

        super().__init__()
        self.tasks = tasks
        self.n_tasks = len(tasks)
        self.branch_config = branch_config

        self._shared_parameters = OrderedDict()
        self._task_specific_parameters = OrderedDict()

        self.stem = modules_fw.ConvBNReLU_fw(in_channels=3,
                                             out_channels=32,
                                             kernel_size=3,
                                             stride=2,
                                             padding=1,
                                             activation='relu6')

        mappings = self._get_branch_mappings()

        # Layers where all tasks are in one group use the *_fw block type and form
        # the shared trunk; from the first layer with several groups on, plain
        # blocks wrapped in BranchedLayer_v2 are used. shared_flag is the index of
        # that first branched layer (-1 if there is none).
        blocks_type = []
        self.shared_flag = -1

        for i, map_list in enumerate(mappings):
            if len(map_list) == 1:
                blocks_type.append(modules_fw.InvertedResidual_fw)
                if self.shared_flag >= 0:
                    raise ValueError('Error: the structure is wrong.')
            else:
                blocks_type.append(modules.InvertedResidual)
                if self.shared_flag < 0:
                    self.shared_flag = i

        # Positional arguments and dilation of the 17 inverted-residual blocks.
        # NOTE(review): presumably (in_channels, out_channels, stride, expansion,
        # dilation) - confirm against modules.InvertedResidual.
        block_specs = [
            (32, 16, 1, 1, 1),
            (16, 24, 2, 6, 1),
            (24, 24, 1, 6, 1),
            (24, 32, 2, 6, 1),
            (32, 32, 1, 6, 1),
            (32, 32, 1, 6, 1),
            (32, 64, 2, 6, 1),
            (64, 64, 1, 6, 1),
            (64, 64, 1, 6, 1),
            (64, 64, 1, 6, 1),
            (64, 96, 1, 6, 1),
            (96, 96, 1, 6, 1),
            (96, 96, 1, 6, 1),
            (96, 160, 2, 6, 2),
            (160, 160, 1, 6, 2),
            (160, 160, 1, 6, 2),
            (160, 320, 1, 6, 2),
        ]
        # blocks_type is indexed (not zipped) so that a branch_config with too few
        # layers still fails with IndexError instead of silently building less.
        blocks = [
            blocks_type[i](c_in, c_out, stride, expand,
                           dilation=dilation, activation='relu6')
            for i, (c_in, c_out, stride, expand, dilation) in enumerate(block_specs)
        ]
        blocks.append(modules.RASPP(320, 128, activation='relu6', drop_rate=0.1))

        shared = [blocks[i] for i in range(self.shared_flag)]
        encoder_blocks = [blocks[i] for i in range(self.shared_flag, len(blocks))]
        encoder_mappings = [mappings[i] for i in range(self.shared_flag, len(blocks))]

        self.shared = nn.Sequential(*shared)
        self.encoder = nn.Sequential(*[modules.BranchedLayer_v2(
            bl, ma, n_tasks=self.n_tasks) for bl, ma in zip(encoder_blocks, encoder_mappings)])

        self.decoder = nn.ModuleDict({
            task: modules.DeepLabV3PlusDecoder(
                in_channels_low=24,
                out_f_classifier=128,
                use_separable=True,
                activation='relu6',
                num_outputs=TASK_CHANNEL_MAPPING[task])
            for task in self.tasks})

        self.x_decoder_mapping = {task: i for i, task in enumerate(self.tasks)}

        if branched == 'empty':
            task_branches = []
        elif branched == 'branch':
            # lw_cos Adam -0.1
            # task_branches = ['shared.0.chain.0.op.0', 'shared.0.chain.1.op.1', 'stem.op.1', 'stem.op.0', 'shared.2.chain.0.op.1',
            #                  'shared.0.chain.0.op.1', 'shared.2.chain.2.op.1', 'shared.1.chain.0.op.1', 'shared.1.chain.2.op.1', 'shared.5.chain.2.op.1',
            #                  'shared.3.chain.0.op.1', 'shared.4.chain.2.op.1', 'shared.3.chain.2.op.1', 'shared.0.chain.1.op.0', 'shared.1.chain.1.op.1',
            #                  'shared.1.chain.1.op.0', 'shared.2.chain.1.op.1', 'shared.4.chain.0.op.1', 'shared.2.chain.1.op.0', 'shared.3.chain.1.op.1',
            #                  'shared.5.chain.0.op.1', 'shared.4.chain.1.op.1', 'shared.4.chain.1.op.0', 'shared.5.chain.1.op.1', 'shared.1.chain.0.op.0']
            # lw_cos Adam 0.0
            task_branches = ['shared.0.chain.0.op.0',
                             'shared.0.chain.1.op.1',
                             'stem.op.0',
                             'stem.op.1',
                             'shared.2.chain.0.op.1',
                             'shared.0.chain.0.op.1',
                             'shared.2.chain.2.op.1',
                             'shared.5.chain.2.op.1',
                             'shared.1.chain.2.op.1',
                             'shared.1.chain.0.op.1',
                             'shared.4.chain.2.op.1',
                             'shared.0.chain.1.op.0',
                             'shared.3.chain.0.op.1',
                             'shared.3.chain.2.op.1',
                             'shared.1.chain.1.op.0',
                             'shared.1.chain.1.op.1',
                             'shared.2.chain.1.op.1',
                             'shared.4.chain.0.op.1',
                             'shared.2.chain.1.op.0',
                             'shared.5.chain.0.op.1',
                             'shared.3.chain.1.op.1',
                             'shared.1.chain.0.op.0',
                             'shared.4.chain.1.op.0',
                             'shared.4.chain.1.op.1',
                             'shared.5.chain.1.op.1',
                             'shared.5.chain.1.op.0',
                             'shared.1.chain.2.op.0',
                             'shared.3.chain.1.op.0',
                             'shared.2.chain.2.op.0',
                             'shared.3.chain.0.op.0',
                             'shared.2.chain.0.op.0',
                             'shared.3.chain.2.op.0',
                             'shared.4.chain.0.op.0',
                             'shared.5.chain.0.op.0',
                             'shared.5.chain.2.op.0',
                             'shared.4.chain.2.op.0']

            # topK=None slices to the full ranking
            task_branches = task_branches[:topK]
        else:
            # Previously an unknown mode left task_branches unbound and failed
            # below with an UnboundLocalError.
            raise ValueError(
                "Unknown branched mode {!r}; expected 'empty' or 'branch'.".format(branched))

        self._obtian_td_layers()
        self.turn(task_branches)

        self._initialize_weights()

        if pretrain:
            self._load_imagenet_weights()

    def turn(self, task_branches=[]):
        self._shared_parameters.clear()

        if self.n_tasks > 1:
            for idx, m in self.named_modules():
                if idx in task_branches:
                    m.set_n_tasks(n_tasks=self.n_tasks)

        task_branches = task_branches + self.td_layers

        # obtain the shared parameters
        for idx, m in self.named_modules():
            if '.m_list' in idx:
                idx = idx[:idx.index('.m_list')]

            if idx not in task_branches:
                members = m._parameters.items()
                memo = set()
                for k, v in members:
                    if v is None or v in memo:
                        continue
                    memo.add(v)
                    name = idx + ('.' if idx else '') + k
                    self._shared_parameters[name] = v
            else:
                members = m._parameters.items()
                memo = set()
                for k, v in members:
                    if v is None or v in memo:
                        continue
                    memo.add(v)
                    name = idx + ('.' if idx else '') + k
                    self._task_specific_parameters[name] = v

    def shared_parameters(self):
        return self._shared_parameters

    def zero_grad_shared_modules(self):
        for name, p in self.shared_parameters().items():
            if p.grad is not None:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()

    def _obtian_td_layers(self):
        """
        Collect into `self.td_layers` the names of all modules that directly own a
        parameter whose full name contains 'decoder' or 'encoder'.

        The result is de-duplicated through a set, so its order is unspecified.
        """
        td_markers = ['decoder', 'encoder']

        self.td_layers = []
        for mod_name, mod in self.named_modules():
            prefix = mod_name + ('.' if mod_name else '')
            seen = set()
            for p_name, p in mod._parameters.items():
                if p is None or p in seen:
                    continue
                seen.add(p)
                if is_father_str(prefix + p_name, td_markers):
                    self.td_layers.append(mod_name)

        self.td_layers = list(set(self.td_layers))

    def model_size(self):
        param_size = 0
        for param in self.parameters():
            param_size += param.nelement() * param.element_size()

        size_all_mb = param_size / 1024 ** 2
        return size_all_mb

    def _get_branch_mappings(self):
        """
        Calculates branch mappings from the branch_config. `mappings` is a list of dicts mapping
        the index of an input branch to indices of output branches. For example:
        mappings = [
            {0: [0]},
            {0: [0, 1]},
            {0: [0, 2], 1: [1, 3]},
            {0: [1], 1: [2], 2: [3], 3: [0]}
        ]
        """

        def get_partition(layer_config):
            partition = []
            for t in range(len(self.tasks)):
                s = {i for i, x in enumerate(layer_config) if x == t}
                if not len(s) == 0:
                    partition.append(s)
            return partition

        def make_refinement(partition, ancestor):
            """ make `partition` a refinement of `ancestor` """
            refinement = []
            for part_1 in partition:
                for part_2 in ancestor:
                    inter = part_1.intersection(part_2)
                    if not len(inter) == 0:
                        refinement.append(inter)
            return refinement

        task_grouping = [set(range(len(self.tasks))), ]

        refinement_all = []
        for layer_idx, layer_config in enumerate(self.branch_config):

            partition = get_partition(layer_config)
            refinement = make_refinement(partition, task_grouping)
            refinement_all.append(refinement)

        return refinement_all

    def _get_output_key_with_task_id(self, keys_list, id):
        out = []
        for keys in keys_list:
            if id in keys:
                out.append(keys)
        assert len(out) == 1
        return out[0]



    def forward(self, x):
        """
        Run stem, shared trunk, branched encoder and the per-task decoders.

        Returns a dict mapping every task name to its decoder output. The low-level
        features handed to the decoders are a copy of the trunk output taken after
        the third shared layer (index 2).
        """
        input_shape = x.shape[-2:]
        x = self.stem(x)

        for depth, layer in enumerate(self.shared):
            x = layer(x)
            if depth == 2:
                out_low = modules_fw.clone_fw(x)

        all_tasks = frozenset(range(self.n_tasks))

        def as_branch_dict(feats):
            # one entry: shared by all tasks; otherwise keyed by position
            if len(feats) == 1:
                return {all_tasks: feats[0]}
            return {pos: feats[pos] for pos in range(len(feats))}

        x = as_branch_dict(x)
        out_low = as_branch_dict(out_low)

        for layer in self.encoder:
            x = layer(x)

        output = {}
        for task in self.tasks:
            task_id = self.x_decoder_mapping[task]
            if len(x) < self.n_tasks:
                # fewer branches than tasks: find the branch holding this task
                high_key = self._get_output_key_with_task_id(x.keys(), task_id)
                low_key = self._get_output_key_with_task_id(out_low.keys(), task_id)
            else:
                high_key = low_key = task_id
            output[task] = self.decoder[task](x[high_key], out_low[low_key], input_shape)

        return output

    def _initialize_weights(self):
        # weight initialization
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'logits' in name:
                    # initialize final prediction layer with fixed std
                    nn.init.normal_(m.weight, mean=0, std=0.01)
                else:
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                if m.affine:
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)

    def _load_imagenet_weights(self):
        """
        Initialize stem, shared trunk and encoder from the torchvision MobileNetV2
        checkpoint at MODEL_URL.

        Each key of this model's state dict is renamed to its checkpoint key.
        Decoder keys and the last encoder entry (index 17 - shared_flag) are
        skipped. Keys outside stem/shared/encoder raise ValueError; a renamed key
        missing from the checkpoint raises KeyError.
        """
        pretrained = torch.hub.load_state_dict_from_url(
            MODEL_URL, progress=False)
        target_state_dict = self.state_dict()

        # the last encoder entry has no pretrained counterpart
        raspp_idx = 17 - self.shared_flag
        not_loadable = ('decoder', f'encoder.{raspp_idx}')

        def conv_key(layer, chain, op, rest):
            # `rest` starts at the op-index token of the original key
            if chain == 0 or (chain == 1 and layer != 0):
                tail = [str(chain)] + rest
            else:
                tail = [str(chain + op)] + rest[1:]
            return '.'.join(['features', str(layer + 1), 'conv'] + tail)

        def without_m_list(key):
            # drop the 'm_list.<idx>' pair used by the *_fw modules
            tokens = key.split('.')
            cut = tokens.index('m_list')
            tokens[cut:cut + 2] = []
            return tokens

        remapped = {}
        for key in target_state_dict:
            if key.startswith(not_loadable):
                continue  # can't load decoder and ASPP

            if key.startswith('stem'):
                src_key = '.'.join(without_m_list(key)).replace(
                    'stem.op', 'features.0')
            elif key.startswith('shared'):
                # this is highly specific to the current naming
                tokens = without_m_list(key)
                layer = int(tokens[1])
                chain = int(tokens[3])
                op = int(tokens[5])
                src_key = conv_key(layer, chain, op, tokens[5:])
            elif key.startswith('encoder'):
                tokens = key.split('.')
                # encoder indices restart at 0 after the shared trunk
                layer = int(tokens[1]) + self.shared_flag
                chain = int(tokens[5])
                op = int(tokens[7])
                # tokens[2:4] (the path and its index) are not part of the source key
                src_key = conv_key(layer, chain, op, tokens[7:])
            else:
                raise ValueError

            remapped[key] = pretrained[src_key]

        self.load_state_dict(remapped, strict=False)

class BranchMobileNetV2(nn.Module):
    """
    Branched multi-task network based on a MobileNetV2 backbone, R-ASPP module and
    DeepLabv3+ head.
    branch_config: Array (or nested list) of shape (n_layers X n_tasks) expressing which
    blocks are sampled for each task at each layer. Example (n_layers=6, n_tasks=4):
    branch_config = [
        [0, 0, 0, 0],
        [2, 2, 2, 2],
        [1, 1, 2, 2],
        [0, 0, 2, 3],
        [0, 1, 2, 3],
        [1, 1, 3, 3]
    ]
    This array determines the branching configuration.
    """

    def __init__(self, tasks, branch_config, pretrain=True):
        """
        Build stem, branched encoder and one decoder per task.

        tasks: sequence of task names; each must be a key of TASK_CHANNEL_MAPPING.
        branch_config: per-layer block assignment of every task (see class docstring).
        pretrain: when true, ImageNet weights are loaded after initialization.
        """

        super().__init__()
        self.tasks = tasks
        self.branch_config = branch_config

        self.stem = modules.ConvBNReLU(in_channels=3,
                                       out_channels=32,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       activation='relu6')

        mappings = self._get_branch_mappings()

        # Positional arguments and dilation of the 17 inverted-residual blocks.
        # NOTE(review): presumably (in_channels, out_channels, stride, expansion,
        # dilation) - confirm against modules.InvertedResidual.
        block_specs = [
            (32, 16, 1, 1, 1),
            (16, 24, 2, 6, 1),
            (24, 24, 1, 6, 1),
            (24, 32, 2, 6, 1),
            (32, 32, 1, 6, 1),
            (32, 32, 1, 6, 1),
            (32, 64, 2, 6, 1),
            (64, 64, 1, 6, 1),
            (64, 64, 1, 6, 1),
            (64, 64, 1, 6, 1),
            (64, 96, 1, 6, 1),
            (96, 96, 1, 6, 1),
            (96, 96, 1, 6, 1),
            (96, 160, 2, 6, 2),
            (160, 160, 1, 6, 2),
            (160, 160, 1, 6, 2),
            (160, 320, 1, 6, 2),
        ]
        blocks = [
            modules.InvertedResidual(
                c_in, c_out, stride, expand, dilation=dilation, activation='relu6')
            for c_in, c_out, stride, expand, dilation in block_specs
        ]
        blocks.append(modules.RASPP(320, 128, activation='relu6', drop_rate=0.1))

        # one BranchedLayer per (block, mapping) pair
        branched_layers = [modules.BranchedLayer(bl, ma)
                           for bl, ma in zip(blocks, mappings)]
        self.encoder = nn.Sequential(*branched_layers)

        decoders = {}
        for task in self.tasks:
            decoders[task] = modules.DeepLabV3PlusDecoder(
                in_channels_low=24,
                out_f_classifier=128,
                use_separable=True,
                activation='relu6',
                num_outputs=TASK_CHANNEL_MAPPING[task])
        self.decoder = nn.ModuleDict(decoders)

        self._initialize_weights()

        if pretrain:
            self._load_imagenet_weights()

    def model_size(self):
        param_size = 0
        for param in self.parameters():
            param_size += param.nelement() * param.element_size()

        size_all_mb = param_size / 1024 ** 2
        return size_all_mb

    def _get_branch_mappings(self):
        """
        Calculates branch mappings from the branch_config. `mappings` is a list of dicts mapping
        the index of an input branch to indices of output branches. For example:
        mappings = [
            {0: [0]},
            {0: [0, 1]},
            {0: [0, 2], 1: [1, 3]},
            {0: [1], 1: [2], 2: [3], 3: [0]}
        ]
        """

        def get_partition(layer_config):
            partition = []
            for t in range(len(self.tasks)):
                s = {i for i, x in enumerate(layer_config) if x == t}
                if not len(s) == 0:
                    partition.append(s)
            return partition

        def make_refinement(partition, ancestor):
            """ make `partition` a refinement of `ancestor` """
            refinement = []
            for part_1 in partition:
                for part_2 in ancestor:
                    inter = part_1.intersection(part_2)
                    if not len(inter) == 0:
                        refinement.append(inter)
            return refinement

        task_grouping = [set(range(len(self.tasks))), ]

        mappings = []
        for layer_idx, layer_config in enumerate(self.branch_config):

            partition = get_partition(layer_config)
            refinement = make_refinement(partition, task_grouping)

            out_dict = {}
            for prev_idx, prev in enumerate(task_grouping):
                out_dict[prev_idx] = []
                for curr_idx, curr in enumerate(refinement):
                    if curr.issubset(prev):
                        out_dict[prev_idx].append(curr_idx)

            task_grouping = refinement
            mappings.append(out_dict)

            if layer_idx == 2:
                self.x_low_decoder_mapping = {
                    task: [t_idx in gr for gr in task_grouping].index(True)
                    for t_idx, task in enumerate(self.tasks)
                }

        self.x_decoder_mapping = {
            task: [t_idx in gr for gr in task_grouping].index(True)
            for t_idx, task in enumerate(self.tasks)
        }
        return mappings

    def forward(self, x):
        input_shape = x.shape[-2:]
        x = self.stem(x)

        out = {0: x}
        for i, layer in enumerate(self.encoder):
            out = layer(out)
            if i == 2:
                out_low = {b: out[b].clone() for b in out.keys()}

        output = {task: self.decoder[task](out[self.x_decoder_mapping[task]],
                                           out_low[self.x_low_decoder_mapping[task]],
                                           input_shape)
                  for task in self.tasks}
        return output

    def _initialize_weights(self):
        # weight initialization
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'logits' in name:
                    # initialize final prediction layer with fixed std
                    nn.init.normal_(m.weight, mean=0, std=0.01)
                else:
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                if m.affine:
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)

    def _load_imagenet_weights(self):
        """
        Initialize stem and encoder from the torchvision MobileNetV2 checkpoint at
        MODEL_URL.

        Each key of this model's state dict is renamed to its checkpoint key. Keys
        under 'decoder' and 'encoder.17' are skipped. Keys outside stem/encoder raise
        ValueError; a renamed key missing from the checkpoint raises KeyError.
        """
        pretrained = torch.hub.load_state_dict_from_url(
            MODEL_URL, progress=False)
        not_loadable = ('decoder', 'encoder.17')  # decoder and ASPP

        remapped = {}
        for key in self.state_dict():
            if key.startswith(not_loadable):
                continue

            if key.startswith('stem'):
                src_key = key.replace('stem.op', 'features.0')
            elif key.startswith('encoder'):
                # this is highly specific to the current naming
                tokens = key.split('.')
                layer = int(tokens[1])
                chain = int(tokens[5])
                op = int(tokens[7])
                # tokens[2:4] (the path and its index) are not part of the source key
                if chain == 0 or (chain == 1 and layer != 0):
                    tail = [str(chain)] + tokens[7:]
                else:
                    tail = [str(chain + op)] + tokens[8:]
                src_key = '.'.join(['features', str(layer + 1), 'conv'] + tail)
            else:
                raise ValueError

            remapped[key] = pretrained[src_key]

        self.load_state_dict(remapped, strict=False)

class SuperMobileNetV2(nn.Module):
    """
    Supergraph encompassing all possible branched multi-task networks, based on a MobileNetV2
    encoder, R-ASPP module and DeepLabv3+ head. The branch configuration distribution is
    parameterized with architecture parameters self.alphas and relaxed for optimization using
    a Gumbel-Softmax.
    """

    def __init__(self, tasks, pretrain=True):
        """
        Build stem, encoder supergraph, per-task decoders and architecture parameters.

        tasks: sequence of task names; each must be a key of TASK_CHANNEL_MAPPING.
        pretrain: when true, ImageNet weights are loaded after initialization.
        """

        super().__init__()
        self.tasks = tasks

        # First conv
        self.stem = modules.ConvBNReLU(in_channels=3,
                                       out_channels=32,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       activation='relu6')

        # Build encoder supergraph.
        # Positional arguments and dilation of the 17 inverted-residual blocks.
        # NOTE(review): presumably (in_channels, out_channels, stride, expansion,
        # dilation) - confirm against modules.InvertedResidual.
        block_specs = [
            (32, 16, 1, 1, 1),
            (16, 24, 2, 6, 1),
            (24, 24, 1, 6, 1),
            (24, 32, 2, 6, 1),
            (32, 32, 1, 6, 1),
            (32, 32, 1, 6, 1),
            (32, 64, 2, 6, 1),
            (64, 64, 1, 6, 1),
            (64, 64, 1, 6, 1),
            (64, 64, 1, 6, 1),
            (64, 96, 1, 6, 1),
            (96, 96, 1, 6, 1),
            (96, 96, 1, 6, 1),
            (96, 160, 2, 6, 2),
            (160, 160, 1, 6, 2),
            (160, 160, 1, 6, 2),
            (160, 320, 1, 6, 2),
        ]
        blocks = [
            modules.InvertedResidual(
                c_in, c_out, stride, expand,
                dilation=dilation, activation='relu6', final_affine=False)
            for c_in, c_out, stride, expand, dilation in block_specs
        ]
        blocks.append(modules.RASPP(320, 128, activation='relu6',
                                    final_affine=False, drop_rate=0.1))

        n_tasks = len(self.tasks)
        self.encoder = nn.ModuleList(
            [modules.SupernetLayer(bl, n_tasks) for bl in blocks])

        # Build decoder
        decoders = {}
        for task in self.tasks:
            decoders[task] = modules.DeepLabV3PlusDecoder(
                in_channels_low=24,
                out_f_classifier=128,
                use_separable=True,
                activation='relu6',
                num_outputs=TASK_CHANNEL_MAPPING[task])
        self.decoder = nn.ModuleDict(decoders)

        self.warmup_flag = False
        self.gumbel_temp = None
        self.gumbel_func = modules.GumbelSoftmax(dim=-1, hard=False)

        self._create_alphas()
        self._initialize_weights()

        if pretrain:
            self._load_imagenet_weights()

    def forward(self, x, task):
        input_shape = x.shape[-2:]
        t_idx = self.tasks.index(task)

        t_feat = self.stem(x)
        for idx, op in enumerate(self.encoder):
            if self.warmup_flag:
                op_weights = torch.eye(len(self.tasks), out=torch.empty_like(
                    x), requires_grad=False)[t_idx, :]
            else:
                op_weights = self.gumbel_func(
                    self.alphas[idx][t_idx, :], temperature=self.gumbel_temp)
            t_feat = op(t_feat, op_weights)
            if idx == 2:
                t_interm = t_feat.clone()

        res = {task: self.decoder[task](t_feat, t_interm, input_shape)}
        return res

    def _initialize_weights(self):
        # weight initialization
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'logits' in name:
                    # initialize final prediction layer with fixed std
                    nn.init.normal_(m.weight, mean=0, std=0.01)
                else:
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                if m.affine:
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)

    def freeze_encoder_bn_running_stats(self):
        for m in self.encoder.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.track_running_stats = False

    def unfreeze_encoder_bn_running_stats(self):
        for m in self.encoder.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.track_running_stats = True

    def reset_encoder_bn_running_stats(self):
        for m in self.encoder.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.reset_running_stats()

    def _create_alphas(self):
        """ Create the architecture parameters of the supergraph. """
        n_tasks = len(self.tasks)
        self.alphas = nn.ParameterList([
            nn.Parameter(torch.zeros(n_tasks, n_tasks)) for _ in self.encoder
        ])

    def _load_imagenet_weights(self):
        """
        Initialize stem and encoder from the torchvision MobileNetV2 checkpoint at
        MODEL_URL.

        Each key of this model's state dict is renamed to its checkpoint key. Keys
        under 'decoder', 'encoder.17' and 'alphas' are skipped. Keys outside
        stem/encoder raise ValueError; a renamed key missing from the checkpoint
        raises KeyError.
        """
        pretrained = torch.hub.load_state_dict_from_url(
            MODEL_URL, progress=False)
        # decoder and ASPP can't be loaded; alphas are architecture parameters
        not_loadable = ('decoder', 'encoder.17', 'alphas')

        remapped = {}
        for key in self.state_dict():
            if key.startswith(not_loadable):
                continue

            if key.startswith('stem'):
                src_key = key.replace('stem.op', 'features.0')
            elif key.startswith('encoder'):
                # this is highly specific to the current naming
                tokens = key.split('.')
                layer = int(tokens[1])
                chain = int(tokens[5])
                op = int(tokens[7])
                # tokens[2:4] (the path and its index) are not part of the source key
                if chain == 0 or (chain == 1 and layer != 0):
                    tail = [str(chain)] + tokens[7:]
                else:
                    tail = [str(chain + op)] + tokens[8:]
                src_key = '.'.join(['features', str(layer + 1), 'conv'] + tail)
            else:
                raise ValueError

            remapped[key] = pretrained[src_key]

        self.load_state_dict(remapped, strict=False)

    def weight_parameters(self):
        for name, param in self.named_weight_parameters():
            yield param

    def named_weight_parameters(self):
        """Return a lazy iterator over the non-architecture parameters.

        Iterates ``self.named_parameters()`` and keeps every
        ``(name, parameter)`` pair whose name does not begin with
        ``'alphas'``.
        """
        def is_weight(named_param):
            param_name = named_param[0]
            return not param_name.startswith('alphas')

        return filter(is_weight, self.named_parameters())

    def arch_parameters(self):
        """Yield the parameters reported by ``named_arch_parameters``.

        Only the parameter of each ``(name, parameter)`` pair is yielded;
        the name is discarded.
        """
        yield from (alpha for _, alpha in self.named_arch_parameters())

    def named_arch_parameters(self):
        """Return a lazy iterator over the architecture parameters.

        Iterates ``self.named_parameters()`` and keeps every
        ``(name, parameter)`` pair whose name begins with ``'alphas'``.
        """
        def is_arch(named_param):
            param_name = named_param[0]
            return param_name.startswith('alphas')

        return filter(is_arch, self.named_parameters())

    def get_branch_config(self):
        """Derive a discrete branch configuration from ``self.alphas``.

        For every alpha matrix a softmax is taken over the last dimension,
        and for each of the ``len(self.tasks)`` rows the index of the largest
        probability is recorded.

        Returns:
            A nested list indexed ``[block][task]``. The indices are stored
            in a default-dtype CPU tensor before conversion, so they come
            back as floats under the default floating-point dtype.
        """
        num_tasks = len(self.tasks)
        config = torch.empty(len(self.alphas), num_tasks, device='cpu')
        for block_idx, alpha in enumerate(self.alphas):
            probs = nn.functional.softmax(alpha, dim=-1).to('cpu').detach()
            # Row ``task_idx`` holds that task's distribution; keep its argmax.
            for task_idx in range(num_tasks):
                config[block_idx, task_idx] = torch.argmax(probs[task_idx, :])
        return config.numpy().tolist()

    def calculate_flops_lut(self, file_name, input_size):
        """Measure the FLOPs of every encoder block and save them as a LUT.

        A single-task (``'semseg'``) ``BranchMobileNetV2`` is built with an
        all-zero branch config of shape ``(18, len(self.tasks))``. For each
        block of its encoder, FLOPs counting is attached to that block, one
        random input is pushed through the whole model, and the block's
        average FLOPs cost is recorded.

        Args:
            file_name: Destination passed to ``utils.write_json``.
            input_size: Indexable whose elements 0 and 1 become the last two
                dimensions of the ``(1, 3, ., .)`` random probe input.

        Returns:
            ``{'per_block_flops': [...]}`` with one entry per encoder block
            of the probe model.
        """
        probe_config = torch.zeros(18, len(self.tasks))
        probe = BranchMobileNetV2(tasks=['semseg'],
                                  branch_config=probe_config)
        probe_shape = (1, 3, input_size[0], input_size[1])
        per_block = torch.zeros(len(probe.encoder), device='cpu')

        probe.eval()
        with torch.no_grad():
            for block_idx, block in enumerate(probe.encoder):
                counted = resources.add_flops_counting_methods(block)
                counted.start_flops_count()
                # A fresh random input per block; only the shape matters here.
                probe_input = torch.rand(probe_shape)
                _ = probe(probe_input)
                block_cost = counted.compute_average_flops_cost()
                counted.stop_flops_count()
                per_block[block_idx] = block_cost
        lut = {'per_block_flops': per_block.numpy().tolist()}
        del probe

        # persist the look-up table so later calls can read it from disk
        utils.write_json(lut, file_name)

        return lut

    def get_flops(self):
        """Return the per-block FLOPs list for a 512x512 input.

        The look-up table is read from
        ``flops_MobileNetV2_512_512.json`` (relative to the working
        directory) when that file exists; otherwise it is computed and
        written by ``calculate_flops_lut``.
        """
        input_size = [512, 512]  # for PASCAL-Context
        lut_path = Path('flops_MobileNetV2_{}_{}.json'.format(*input_size))

        if not lut_path.is_file():
            print('no LUT found, calculating FLOPS...')
            lut = self.calculate_flops_lut(lut_path, input_size)
            return lut['per_block_flops']

        return utils.read_json(lut_path)['per_block_flops']

if __name__ == '__main__':
    # Ad-hoc script: print the value returned by ``model_size()`` for several
    # model variants built on the same four tasks, plus the summed size of
    # four single-task models.
    tasks = ['semseg', 'human_parts', 'sal', 'normals']
    top_k = 85

    # branch_config = utils.read_json('../exp_semseg-human_parts-sal-normals-edge_0.1_True_09-06-15-59-19/search/branch_config.json')['config']
    # branch_config = utils.read_json('../exp_semseg-human_parts-sal-normals_0.1_True_130_09-12-18-33-02/search/branch_config.json')['config']
    # input = torch.rand(2, 3, 512, 512).cuda()
    #
    # net = BranchMobileNetV2_fw_v2(tasks=['semseg','human_parts','sal','normals'],  branch_config=branch_config, branched='branch', topK=10).cuda()
    # out = net(input)

    # Every constructor gets its own copy of the task list so that no two
    # models share one mutable list object.
    size_normal = MoblieNetV2_fw(
        tasks=list(tasks), branched='empty').model_size()

    net = MoblieNetV2_fw(tasks=list(tasks), branched='branch', topK=top_k)
    size_recon = net.model_size()

    configuration = '../exp_semseg-human_parts-sal-normals_0.1_True_130_09-12-18-33-02/search/branch_config.json'
    branch_config = utils.read_json(configuration)['config']
    size_bmtas = BranchMobileNetV2_fw_v2(
        tasks=list(tasks),
        branch_config=branch_config,
        branched='empty').model_size()

    size_bmtas_recon = BranchMobileNetV2_fw_v2(
        tasks=list(tasks),
        branch_config=branch_config,
        branched='branch',
        topK=top_k).model_size()

    print(f'normal size: {size_normal:.2f}')
    print(f'recon size: {size_recon:.2f}')
    print(f'bmtas size: {size_bmtas:.2f}')
    print(f'bmtas recon size: {size_bmtas_recon:.2f}')

    # One independent single-task model per task, sizes added in task order.
    single_task_total = sum(
        MoblieNetV2_fw(tasks=[task], branched='empty').model_size()
        for task in tasks)

    print(f'single size: {single_task_total}')

    # out = net(input)

    # net2 = BranchMobileNetV2(tasks=['semseg','human_parts','sal','normals','edge'], branch_config=branch_config)

    # print(f'net1: {net.model_size()}')
    # print(f'net2: {net2.model_size()}')

    # The snippet below is the only user of numpy/random, so their imports
    # live here (commented out) instead of being imported and left unused.
    # import numpy as np
    # import random
    #
    # torch.backends.cudnn.benchmark = False
    # torch.backends.cudnn.deterministic = True
    #
    # torch.manual_seed(0)
    # np.random.seed(0)
    # random.seed(0)
    # torch.cuda.manual_seed_all(0)
    # torch.set_num_threads(4)
    #
    # input = torch.rand(2, 3, 512, 512).cuda()
    # net1 = MoblieNetV2(tasks=['semseg','human_parts','sal','normals','edge']).cuda()
    # net2 = MoblieNetV2_fw(tasks=['semseg','human_parts','sal','normals','edge']).cuda()
    #
    # v1 = []
    # v2 = []
    # keys = []
    # for i, (key, value) in enumerate(net1.named_parameters()):
    #     v1.append(value)
    #     keys.append(key)
    # print('-------------------------------------------------')
    # for i, (key, value) in enumerate(net2.named_parameters()):
    #     v2.append(value)
    #
    # diff = []
    # for i in range(len(v1)):
    #     zeros = v1[i] - v2[i]
    #     zeros = zeros.sum()
    #     if zeros != 0:
    #         diff.append(i)
    #
    # for i, (key, value) in enumerate(net2.named_parameters()):
    #     if i in diff:
    #         value.data = torch.clone(v1[i])
    #
    # v1 = []
    # v2 = []
    # keys = []
    # for i, (key, value) in enumerate(net1.named_parameters()):
    #     v1.append(value)
    #     keys.append(key)
    # print('-------------------------------------------------')
    # for i, (key, value) in enumerate(net2.named_parameters()):
    #     v2.append(value)
    #
    # diff = []
    # for i in range(len(v1)):
    #     zeros = v1[i] - v2[i]
    #     zeros = zeros.sum()
    #     print(zeros)
    #
    # output = net1(input)
    #
    # output2 = net2(input)

    print('Finished')