import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from collections import OrderedDict

# task_branches = []


# def is_branches():


class Linear_fw(nn.Module):
    """Task-branchable ``nn.Linear``.

    Keeps ``n_tasks`` ``nn.Linear`` copies in ``m_list``. ``forward`` accepts a
    tensor or a list of tensors and always returns a list:

    * one input, several copies  -> every copy is applied to that input;
    * several inputs, one copy   -> the single copy is applied to each input;
    * k inputs, k copies         -> input i goes through copy i;
    * anything else              -> ``ValueError``.
    """

    def __init__(self, in_features, out_features, n_tasks=1):
        super(Linear_fw, self).__init__()
        self.n_tasks = n_tasks
        layers = [nn.Linear(in_features, out_features) for _ in range(n_tasks)]
        self.m_list = nn.ModuleList(layers)

    def set_n_tasks(self, n_tasks=1):
        """Grow to ``n_tasks`` copies by deep-copying copy 0 (never shrinks).

        Raises:
            ValueError: if ``n_tasks`` is smaller than the current count.
        """
        if n_tasks < self.n_tasks:
            raise ValueError('Can not decrease the number of tasks in fw module')
        missing = n_tasks - self.n_tasks
        self.n_tasks = n_tasks
        # new copies start from the current weights of copy 0
        self.m_list.extend([deepcopy(self.m_list[0]) for _ in range(missing)])

    def forward(self, x):
        inputs = x if isinstance(x, list) else [x]
        n_inputs = len(inputs)
        if n_inputs == 1 and self.n_tasks > 1:
            # fan a single shared input out to every task copy
            return [layer(inputs[0]) for layer in self.m_list]
        if n_inputs > 1 and self.n_tasks == 1:
            only = self.m_list[0]
            return [only(t) for t in inputs]
        if n_inputs != self.n_tasks:
            raise ValueError('Error')
        return [layer(t) for layer, t in zip(self.m_list, inputs)]

class Conv2d_fw(nn.Module):
    """Task-branchable ``nn.Conv2d``.

    Holds ``n_tasks`` identically configured ``nn.Conv2d`` copies in ``m_list``.
    ``forward`` takes a tensor or a list of tensors and returns a list:

    * one input, several copies  -> every copy is applied to that input;
    * several inputs, one copy   -> the single copy is applied to each input;
    * k inputs, k copies         -> input i goes through copy i;
    * anything else              -> ``ValueError``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, groups=1, dilation=1, n_tasks=1):
        super(Conv2d_fw, self).__init__()
        self.n_tasks = n_tasks

        def make_conv():
            return nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                             dilation=dilation, groups=groups, bias=bias)

        self.m_list = nn.ModuleList([make_conv() for _ in range(n_tasks)])

    def set_n_tasks(self, n_tasks=1):
        """Grow to ``n_tasks`` copies by deep-copying copy 0 (never shrinks).

        Raises:
            ValueError: if ``n_tasks`` is smaller than the current count.
        """
        if n_tasks < self.n_tasks:
            raise ValueError('Can not decrease the number of tasks in fw module')
        missing = n_tasks - self.n_tasks
        self.n_tasks = n_tasks
        # new copies start from the current weights of copy 0
        self.m_list.extend([deepcopy(self.m_list[0]) for _ in range(missing)])

    def forward(self, x):
        inputs = x if isinstance(x, list) else [x]
        n_inputs = len(inputs)
        if n_inputs == 1 and self.n_tasks > 1:
            # fan a single shared input out to every task copy
            return [conv(inputs[0]) for conv in self.m_list]
        if n_inputs > 1 and self.n_tasks == 1:
            only = self.m_list[0]
            return [only(t) for t in inputs]
        if n_inputs != self.n_tasks:
            raise ValueError('Error')
        return [conv(t) for conv, t in zip(self.m_list, inputs)]

class BatchNorm2d_fw(nn.Module):
    """Task-branchable ``nn.BatchNorm2d``.

    Holds ``n_tasks`` ``nn.BatchNorm2d(num_features)`` copies in ``m_list``.
    ``forward`` takes a tensor or a list of tensors and returns a list:

    * one input, several copies  -> every copy is applied to that input;
    * several inputs, one copy   -> the single copy is applied to each input;
    * k inputs, k copies         -> input i goes through copy i;
    * anything else              -> ``ValueError``.
    """

    def __init__(self, num_features, n_tasks=1):
        super(BatchNorm2d_fw, self).__init__()
        self.n_tasks = n_tasks
        norms = [nn.BatchNorm2d(num_features) for _ in range(n_tasks)]
        self.m_list = nn.ModuleList(norms)

    def set_n_tasks(self, n_tasks=1):
        """Grow to ``n_tasks`` copies by deep-copying copy 0 (never shrinks).

        Raises:
            ValueError: if ``n_tasks`` is smaller than the current count.
        """
        if n_tasks < self.n_tasks:
            raise ValueError('Can not decrease the number of tasks in fw module')
        missing = n_tasks - self.n_tasks
        self.n_tasks = n_tasks
        # new copies start from the current state (weights and running stats) of copy 0
        self.m_list.extend([deepcopy(self.m_list[0]) for _ in range(missing)])

    def forward(self, x):
        inputs = x if isinstance(x, list) else [x]
        n_inputs = len(inputs)
        if n_inputs == 1 and self.n_tasks > 1:
            # fan a single shared input out to every task copy
            return [norm(inputs[0]) for norm in self.m_list]
        if n_inputs > 1 and self.n_tasks == 1:
            only = self.m_list[0]
            return [only(t) for t in inputs]
        if n_inputs != self.n_tasks:
            raise ValueError('Error')
        return [norm(t) for norm, t in zip(self.m_list, inputs)]

class ReLU_fw(nn.ReLU):
    """``nn.ReLU`` applied to each tensor of a list (a bare tensor counts as a one-item list)."""

    def __init__(self, inplace=False):
        super(ReLU_fw, self).__init__(inplace)

    def forward(self, x):
        batch = x if isinstance(x, list) else [x]
        # bind the parent forward here; zero-arg super() is unreliable inside a comprehension
        relu = super(ReLU_fw, self).forward
        return [relu(t) for t in batch]

class Sigmoid_fw(nn.Sigmoid):
    """``nn.Sigmoid`` applied to each tensor of a list (a bare tensor counts as a one-item list)."""

    def __init__(self):
        super(Sigmoid_fw, self).__init__()

    def forward(self, x):
        batch = x if isinstance(x, list) else [x]
        sigmoid = super(Sigmoid_fw, self).forward
        return [sigmoid(t) for t in batch]

class MaxPool2d_fw(nn.MaxPool2d):
    """``nn.MaxPool2d`` applied to each tensor of a list.

    Returns a list of pooled tensors, or ``(pooled_list, indices_list)`` when
    ``return_indices`` is set.
    """

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False):
        super(MaxPool2d_fw, self).__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)

    def forward(self, x):
        batch = x if isinstance(x, list) else [x]
        pool = super(MaxPool2d_fw, self).forward
        results = [pool(t) for t in batch]
        if not self.return_indices:
            return results
        # each result is a (pooled, indices) pair; split them into two parallel lists
        pooled = [pair[0] for pair in results]
        where = [pair[1] for pair in results]
        return pooled, where

class MaxUnpool2d_fw(nn.MaxUnpool2d):
    """``nn.MaxUnpool2d`` over lists of inputs and pooling indices.

    ``x`` and ``indices`` may each be a tensor or a list of tensors. Pairing rules:

    * equal lengths          -> ``x[i]`` is unpooled with ``indices[i]``;
    * many ``x``, one index  -> every ``x[i]`` uses ``indices[0]``;
    * one ``x``, many index  -> ``x[0]`` is unpooled with each ``indices[i]``;
    * anything else          -> ``ValueError``.

    Always returns a list.
    """

    def __init__(self, kernel_size, stride=None, padding=0):
        super(MaxUnpool2d_fw, self).__init__(kernel_size, stride, padding)

    def forward(self, x, indices):
        if not isinstance(x, list):
            x = [x]
        # Wrap a bare indices tensor too; otherwise len(indices) below would be the
        # tensor's first dimension (the batch size), not the number of index sets.
        if not isinstance(indices, list):
            indices = [indices]

        unpool = super(MaxUnpool2d_fw, self).forward
        if len(x) == len(indices):
            return [unpool(x_i, idx_i) for x_i, idx_i in zip(x, indices)]
        if len(x) > 1 and len(indices) == 1:
            return [unpool(x_i, indices[0]) for x_i in x]
        if len(x) == 1 and len(indices) > 1:
            # broadcast the single input, mirroring dot_fw / cat_fw
            return [unpool(x[0], idx_i) for idx_i in indices]
        raise ValueError('Error!')

class AdaptiveAvgPool2d_fw(nn.AdaptiveAvgPool2d):
    """``nn.AdaptiveAvgPool2d`` applied to each tensor of a list (a bare tensor counts as a one-item list)."""

    def __init__(self, output_size):
        super(AdaptiveAvgPool2d_fw, self).__init__(output_size)

    def forward(self, x):
        batch = x if isinstance(x, list) else [x]
        avg_pool = super(AdaptiveAvgPool2d_fw, self).forward
        return [avg_pool(t) for t in batch]

def dot_fw(x, y):
    """Element-wise product ``x_i * y_i`` over two tensor lists.

    Either argument may be a bare tensor (treated as a one-item list). A
    one-item side is broadcast against a longer side; otherwise both lists
    must have the same length, else ``ValueError``. Returns a list.
    """
    xs = x if isinstance(x, list) else [x]
    ys = y if isinstance(y, list) else [y]

    if len(xs) == 1 and len(ys) > 1:
        xs = xs * len(ys)
    elif len(xs) > 1 and len(ys) == 1:
        ys = ys * len(xs)
    elif len(xs) != len(ys):
        raise ValueError('Error')

    return [a * b for a, b in zip(xs, ys)]

def cat_fw(x, y, dim=0):
    """Concatenate ``x_i`` and ``y_i`` along ``dim`` for two tensor lists.

    Either argument may be a bare tensor (treated as a one-item list). A
    one-item side is broadcast against a longer side; otherwise both lists
    must have the same length, else ``ValueError``. Returns a list.
    """
    xs = x if isinstance(x, list) else [x]
    ys = y if isinstance(y, list) else [y]

    if len(xs) == 1 and len(ys) > 1:
        xs = xs * len(ys)
    elif len(xs) > 1 and len(ys) == 1:
        ys = ys * len(xs)
    elif len(xs) != len(ys):
        raise ValueError('Error')

    return [torch.cat((a, b), dim=dim) for a, b in zip(xs, ys)]

def interpolate_fw(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
    """Apply ``F.interpolate`` with the given options to each tensor of a list.

    ``input`` may be a bare tensor (treated as a one-item list). Returns a list.
    """
    tensors = input if isinstance(input, list) else [input]
    options = dict(size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners,
                   recompute_scale_factor=recompute_scale_factor, antialias=antialias)
    return [F.interpolate(t, **options) for t in tensors]

def is_father_str(str, str_list):
    """Return True if any entry of ``str_list`` occurs as a substring of ``str``."""
    return any(candidate in str for candidate in str_list)

class SegNet_fw(nn.Module):
    """SegNet-style encoder/decoder with optional per-task attention and task branches.

    The shared trunk has 5 encoder and 5 decoder stages (filters 64-128-256-512-512)
    built from ``Conv2d_fw`` / ``BatchNorm2d_fw`` layers, so any of them can be split
    into per-task copies via ``turn``/``add_branch_layers``. With ``attention=True``
    every stage also gets attention modules (``encoder_att`` / ``decoder_att``) with
    ``n_tasks`` copies each. Three prediction heads output ``self.class_nb`` (13),
    1 and 3 channels; the 3-channel output is L2-normalised over dim 1.
    NOTE(review): presumably segmentation / depth / surface normals — confirm with
    the training code.

    Args:
        n_tasks: number of task-specific copies for attention and branched layers.
        branched: 'empty' (no extra branches), 'ablation' or 'branch' (hard-coded
            layer lists truncated to ``topK``); anything else raises ``ValueError``.
        attention: build and use the attention path.
        topK: how many entries of the selected branch list to keep.
        T: which of the three prediction heads (``pred_task1..3``) to build.
    """

    def __init__(self, n_tasks=3, branched='empty', attention=True, topK=5, T=(True, True, True)):
        super(SegNet_fw, self).__init__()
        # initialise network parameters
        filters = [64, 128, 256, 512, 512]
        self.class_nb = 13
        self.n_tasks = n_tasks
        self.attention = attention

        # define encoder decoder layers
        self.encoder_block = nn.ModuleList([self.conv_layer([3, filters[0]])])
        self.decoder_block = nn.ModuleList([self.conv_layer([filters[0], filters[0]])])
        for i in range(4):
            self.encoder_block.append(self.conv_layer([filters[i], filters[i + 1]]))
            self.decoder_block.append(self.conv_layer([filters[i + 1], filters[i]]))

        # define convolution layer
        self.conv_block_enc = nn.ModuleList([self.conv_layer([filters[0], filters[0]])])
        self.conv_block_dec = nn.ModuleList([self.conv_layer([filters[0], filters[0]])])
        for i in range(4):
            if i == 0:
                self.conv_block_enc.append(self.conv_layer([filters[i + 1], filters[i + 1]]))
                self.conv_block_dec.append(self.conv_layer([filters[i], filters[i]]))
            else:
                self.conv_block_enc.append(nn.Sequential(self.conv_layer([filters[i + 1], filters[i + 1]]),
                                                         self.conv_layer([filters[i + 1], filters[i + 1]])))
                self.conv_block_dec.append(nn.Sequential(self.conv_layer([filters[i], filters[i]]),
                                                         self.conv_layer([filters[i], filters[i]])))

        if attention:
            # define task attention layers
            self.encoder_att = nn.ModuleList([self.att_layer([filters[0], filters[0], filters[0]], n_tasks=n_tasks)])
            self.decoder_att = nn.ModuleList([self.att_layer([2 * filters[0], filters[0], filters[0]], n_tasks=n_tasks)])

            for i in range(4):
                self.encoder_att.append(self.att_layer([2 * filters[i + 1], filters[i + 1], filters[i + 1]], n_tasks=n_tasks))
                self.decoder_att.append(self.att_layer([filters[i + 1] + filters[i], filters[i], filters[i]], n_tasks=n_tasks))

            self.encoder_block_att = nn.ModuleList([self.conv_layer([filters[0], filters[1]])])
            self.decoder_block_att = nn.ModuleList([self.conv_layer([filters[0], filters[0]])])
            for i in range(4):
                if i < 3:
                    self.encoder_block_att.append(self.conv_layer([filters[i + 1], filters[i + 2]]))
                    self.decoder_block_att.append(self.conv_layer([filters[i + 1], filters[i]]))
                else:
                    self.encoder_block_att.append(self.conv_layer([filters[i + 1], filters[i + 1]]))
                    self.decoder_block_att.append(self.conv_layer([filters[i + 1], filters[i + 1]]))

        if T[0]:
            self.pred_task1 = self.conv_layer([filters[0], self.class_nb], pred=True)

        if T[1]:
            self.pred_task2 = self.conv_layer([filters[0], 1], pred=True)

        if T[2]:
            self.pred_task3 = self.conv_layer([filters[0], 3], pred=True)

        # define pooling and unpooling functions
        self.max_pool = MaxPool2d_fw(kernel_size=2, stride=2)

        self.down_sampling = MaxPool2d_fw(kernel_size=2, stride=2, return_indices=True)
        self.up_sampling = MaxUnpool2d_fw(kernel_size=2, stride=2)

        # NOTE(review): one value per prediction head; presumably log-variance
        # weights for uncertainty-based loss weighting — confirm with the trainer.
        self.logsigma = nn.Parameter(torch.FloatTensor([-0.5, -0.5, -0.5]))

        self._shared_parameters = OrderedDict()

        # obtain the defined task dependent layers: every *_fw layer (name with the
        # '.m_list.<i>' suffix stripped) that owns parameters under a task-dependent module
        td_names = self.task_dependent_modules_names()
        td_layers = set()
        for idx, m in self.named_modules():
            for k, v in m._parameters.items():
                if v is None:
                    continue
                name = idx + ('.' if idx else '') + k
                if is_father_str(name, td_names):
                    td_layers.add(idx[:idx.index('.m_list')])

        self.td_layers = list(td_layers)

        # insert task branches
        if branched == 'empty':
            task_branches = []
        elif branched == 'ablation':
            # task_branches = ['encoder_block.0.bn1', 'encoder_block.3.bn1', 'encoder_block.1.bn1', 'encoder_block.2.bn1', 'encoder_block_att.0.bn1',
            #                  'conv_block_enc.3.0.bn1', 'conv_block_enc.2.0.bn1', 'encoder_block.4.bn1', 'decoder_block.4.bn1', 'conv_block_enc.1.bn1:']
            task_branches = ['encoder_block_att.2.conv1', 'encoder_block_att.1.conv1', 'encoder_block_att.3.conv1', 'decoder_block_att.0.conv1', 'decoder_block_att.3.conv1']
            task_branches = task_branches[:topK]
        elif branched == 'branch':

            # lw_task_cos_flood_0d5_ep50
            # task_branches = ['encoder_block_att.0.bn1', 'encoder_block_att.0.conv1', 'encoder_block_att.1.bn1', 'encoder_block_att.1.conv1', 'encoder_block_att.2.bn1',
            #                  'conv_block_dec.0.bn1', 'encoder_block_att.2.conv1', 'encoder_block_att.3.bn1', 'encoder_block_att.3.conv1', 'conv_block_dec.0.conv1']

            # lw_task nothing L=0.0
            # task_branches = ['decoder_block_att.1.bn1',
            #                 'decoder_block_att.2.conv1',
            #                 'decoder_block_att.1.conv1',
            #                 'decoder_block_att.3.conv1',
            #                 'encoder_block_att.0.bn1',
            #                 'conv_block_dec.0.bn1',
            #                 'encoder_block_att.0.conv1',
            #                 'decoder_block_att.3.bn1',
            #                 'decoder_block_att.0.bn1',
            #                 'encoder_block_att.1.bn1',
            #                 'decoder_block_att.0.conv1',
            #                 'encoder_block_att.1.conv1',
            #                 'encoder_block_att.2.bn1',
            #                 'decoder_block_att.2.bn1',
            #                 'decoder_block_att.4.bn1',
            #                 'encoder_block_att.3.bn1',
            #                 'encoder_block_att.2.conv1',
            #                 'conv_block_dec.0.conv1',
            #                 'conv_block_dec.1.bn1',
            #                 'decoder_block.0.bn1',
            #                 'encoder_block_att.3.conv1',
            #                 'conv_block_dec.2.1.bn1',
            #                 'decoder_block.1.bn1',
            #                 'conv_block_dec.1.conv1']

            # lw_task nothing L = -0.02, epoch=50
            task_branches = ['conv_block_dec.0.bn1',
                            'decoder_block_att.1.bn1',
                            'decoder_block_att.0.bn1',
                            'encoder_block_att.0.bn1',
                            'decoder_block_att.0.conv1',
                            'decoder_block_att.1.conv1',
                            'encoder_block_att.1.bn1',
                            'decoder_block_att.3.bn1',
                            'conv_block_dec.1.bn1',
                            'decoder_block_att.2.bn1',
                            'decoder_block.0.bn1',
                            'encoder_block_att.2.bn1',
                            'decoder_block.2.bn1',
                            'decoder_block.1.bn1',
                            'conv_block_dec.2.0.bn1',
                            'conv_block_dec.0.conv1',
                            'decoder_block_att.4.bn1',
                            'conv_block_dec.2.1.bn1',
                            'encoder_block_att.3.bn1',
                            'conv_block_dec.1.conv1',
                            'decoder_block.0.conv1',
                            'decoder_block_att.2.conv1',
                            'decoder_block.3.bn1',
                            'conv_block_dec.3.0.bn1',
                            'conv_block_dec.3.1.bn1']

            task_branches = task_branches[:topK]
            # task_branches = ['encoder_block_att.0.bn1', 'encoder_block_att.0.conv1', 'encoder_block_att.3.bn1',
            #                  'encoder_block_att.3.conv1', 'encoder_block_att.2.conv1', 'encoder_block_att.1.conv1',
            #                  'encoder_block_att.2.bn1', 'encoder_block_att.3.bn1', 'decoder_block_att.0.bn1', 'decoder_block_att.3.bn1']

            # task_branches = ['encoder_block_att.3.bn1', 'encoder_block_att.2.bn1', 'encoder_block_att.3.conv1',
            #                  'encoder_block_att.1.bn1', 'encoder_block_att.2.conv1', 'encoder_block_att.1.conv1']

            task_branches = sorted(task_branches)
        else:
            raise ValueError('Not Defined!')

        self.turn(task_branches=task_branches)

        # (re)initialise every copy, including the ones created by turn()
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def turn(self, task_branches=None):
        """Branch the given layers per task and rebuild the shared-parameter registry.

        Args:
            task_branches: names from ``named_modules()`` of ``*_fw`` layers that get
                ``self.n_tasks`` copies (only when ``self.n_tasks > 1``). ``None``
                means no new branches.

        Afterwards ``self._shared_parameters`` maps parameter names (with the
        ``.m_list.<i>`` part removed) to every parameter that belongs neither to a
        listed branch layer nor to ``self.td_layers``, excluding ``logsigma``.
        """
        task_branches = [] if task_branches is None else list(task_branches)
        self._shared_parameters.clear()

        if self.n_tasks > 1:
            self.add_branch_layers(task_branches)

        # layers whose parameters are task specific and therefore not shared
        excluded = set(task_branches) | set(self.td_layers)

        # obtain the shared parameters
        for idx, m in self.named_modules():
            if '.m_list' in idx:
                # per-task copies are registered under their *_fw layer's name
                idx = idx[:idx.index('.m_list')]

            if idx in excluded:
                continue

            memo = set()
            for k, v in m._parameters.items():
                if v is None or v in memo:
                    continue
                memo.add(v)
                name = idx + ('.' if idx else '') + k
                self._shared_parameters[name] = v

        # delete logsigma
        del self._shared_parameters['logsigma']

    def shared_parameters(self):
        """Return the name -> parameter dict built by the last ``turn`` call."""
        return self._shared_parameters

    # def shared_modules(self):
    #     return [self.encoder_block, self.decoder_block,
    #             self.conv_block_enc, self.conv_block_dec,
    #             # self.encoder_att, self.decoder_att,
    #             self.encoder_block_att, self.decoder_block_att,
    #             self.down_sampling, self.up_sampling]

    def shared_modules_name(self):
        """Names of the top-level module groups of the shared trunk."""
        return ['encoder_block', 'decoder_block',
                'conv_block_enc', 'conv_block_dec',
                'encoder_block_att', 'decoder_block_att',
                'down_sampling', 'up_sampling']

    def task_dependent_modules_names(self):
        """Names of module groups whose parameters are always task specific."""
        return ['encoder_att', 'decoder_att',
                'pred_task1', 'pred_task2', 'pred_task3']

    def zero_grad_shared_modules(self):
        """Zero (and detach from any graph) the gradients of the shared parameters."""
        for p in self.shared_parameters().values():
            if p.grad is not None:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()

    def add_branch_layers(self, layer_list):
        """Give every ``*_fw`` layer named in ``layer_list`` ``self.n_tasks`` copies."""
        wanted = set(layer_list)
        for idx, m in self.named_modules():
            if idx in wanted:
                m.set_n_tasks(n_tasks=self.n_tasks)

    def conv_layer(self, channel, pred=False):
        """Build a conv block from ``channel = [in, out]``.

        ``pred=False``: 3x3 conv (padding 1) -> BN -> in-place ReLU.
        ``pred=True``: 3x3 conv keeping ``in`` channels -> 1x1 conv to ``out`` channels.
        """
        if not pred:
            conv_block = nn.Sequential(
                OrderedDict([
                    ('conv1', Conv2d_fw(in_channels=channel[0], out_channels=channel[1], kernel_size=3, padding=1)),
                    ('bn1', BatchNorm2d_fw(num_features=channel[1])),
                    ('relu1', ReLU_fw(inplace=True))
                ])
            )
        else:
            conv_block = nn.Sequential(
                OrderedDict([
                    ('conv1', Conv2d_fw(in_channels=channel[0], out_channels=channel[0], kernel_size=3, padding=1)),
                    ('conv2', Conv2d_fw(in_channels=channel[0], out_channels=channel[1], kernel_size=1, padding=0))
                ])
            )
        return conv_block

    def att_layer(self, channel, n_tasks=1):
        """Attention block from ``channel = [in, mid, out]`` with ``n_tasks`` copies per layer.

        1x1 conv -> BN -> ReLU -> 1x1 conv -> BN -> Sigmoid.
        """
        att_block = nn.Sequential(
            OrderedDict([
                ('conv1', Conv2d_fw(in_channels=channel[0], out_channels=channel[1], kernel_size=1, padding=0, n_tasks=n_tasks)),
                ('bn1', BatchNorm2d_fw(channel[1], n_tasks=n_tasks)),
                ('relu1', ReLU_fw(inplace=True)),
                ('conv2', Conv2d_fw(in_channels=channel[1], out_channels=channel[2], kernel_size=1, padding=0, n_tasks=n_tasks)),
                ('bn2', BatchNorm2d_fw(channel[2], n_tasks=n_tasks)),
                ('Sigmoid2', Sigmoid_fw())
            ])
        )
        return att_block

    def model_size(self):
        """Total parameter storage of the model in MiB."""
        param_size = sum(p.nelement() * p.element_size() for p in self.parameters())
        return param_size / 1024 ** 2

    def _select_task_features(self, feats):
        """Pick the feature map consumed by each of the three prediction heads.

        A single feature map is shared by all heads; otherwise there must be exactly
        ``self.n_tasks`` (>= 3) maps and head k uses map k.

        Raises:
            ValueError: for any other number of feature maps.
        """
        if len(feats) == 1:
            # tested first so that n_tasks == 1 never indexes past the end
            return [feats[0]] * 3
        if len(feats) == self.n_tasks and len(feats) >= 3:
            return list(feats[:3])
        raise ValueError('Error')

    def forward(self, x):
        """Run the network.

        Args:
            x: input image batch; the first encoder conv expects 3 input channels.

        Returns:
            ``([t1_pred, t2_pred, t3_pred], self.logsigma)``; ``t3_pred`` is
            divided by its L2 norm over dim 1.
        """
        g_encoder, g_decoder, g_maxpool, g_upsampl, indices = ([0] * 5 for _ in range(5))
        for i in range(5):
            g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))

        # define attention list for tasks
        atten_encoder, atten_decoder = ([0] * 5 for _ in range(2))
        for j in range(5):
            atten_encoder[j], atten_decoder[j] = ([0] * 3 for _ in range(2))

        # shared encoder: stage i consumes the image (i == 0) or the previous pooled map
        for i in range(5):
            enc_in = x if i == 0 else g_maxpool[i - 1]
            g_encoder[i][0] = self.encoder_block[i](enc_in)
            g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
            g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])

        # shared decoder: unpool with the indices of the mirrored encoder stage
        for i in range(5):
            dec_in = g_maxpool[-1] if i == 0 else g_decoder[i - 1][-1]
            g_upsampl[i] = self.up_sampling(dec_in, indices[-i - 1])
            g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
            g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])

        if self.attention:
            for j in range(5):
                if j == 0:
                    att_in = g_encoder[j][0]
                else:
                    att_in = cat_fw(g_encoder[j][0], atten_encoder[j - 1][2], dim=1)
                atten_encoder[j][0] = self.encoder_att[j](att_in)
                atten_encoder[j][1] = dot_fw(atten_encoder[j][0], g_encoder[j][1])
                atten_encoder[j][2] = self.max_pool(self.encoder_block_att[j](atten_encoder[j][1]))

            for j in range(5):
                prev = atten_encoder[-1][-1] if j == 0 else atten_decoder[j - 1][2]
                upsampled = interpolate_fw(prev, scale_factor=2, mode='bilinear', align_corners=True)
                atten_decoder[j][0] = self.decoder_block_att[-j - 1](upsampled)
                atten_decoder[j][1] = self.decoder_att[-j - 1](
                    cat_fw(g_upsampl[j], atten_decoder[j][0], dim=1))
                atten_decoder[j][2] = dot_fw(atten_decoder[j][1], g_decoder[j][-1])

            final_feats = atten_decoder[-1][-1]
        else:
            final_feats = g_decoder[-1][-1]

        f1, f2, f3 = self._select_task_features(final_feats)
        t1_pred = self.pred_task1(f1)[0]
        t2_pred = self.pred_task2(f2)[0]
        t3_pred = self.pred_task3(f3)[0]

        t3_pred = t3_pred / torch.norm(t3_pred, p=2, dim=1, keepdim=True)

        return [t1_pred, t2_pred, t3_pred], self.logsigma

if __name__ == '__main__':
    # Print parameter memory (MiB, via model_size()) of: the unbranched model,
    # the branched model (topK=15), and the sum of three single-head models
    # built without attention.
    net = SegNet_fw(branched='empty', T=[False, False, False])
    model_size = net.model_size()
    print(f'normal: {model_size}')
    net = SegNet_fw(branched='branch', topK=15, T=[False, False, False])
    model_size = net.model_size()
    print(f'RMConflict: {model_size}')
    net_t1 = SegNet_fw(branched='empty', attention=False, T=[True, False, False])
    net_t2 = SegNet_fw(branched='empty', attention=False, T=[False, True, False])
    net_t3 = SegNet_fw(branched='empty', attention=False, T=[False, False, True])

    m = net_t1.model_size() + net_t2.model_size() + net_t3.model_size()
    print(f'Single: {m}')

    # x = torch.rand(2, 3, 288, 384)
    # output = net(x)
