import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from collections import OrderedDict
import json
import random

# task_branches = []


# def is_branches():


class Linear_fw(nn.Module):  # used in MAML to forward input with fast weight
    """A bank of ``nn.Linear`` layers, one per task, stored in ``m_list``.

    ``forward`` accepts a tensor or a list of tensors and always returns a list.
    """

    def __init__(self, in_features, out_features, n_tasks=1):
        super().__init__()
        self.n_tasks = n_tasks
        self.m_list = nn.ModuleList([nn.Linear(in_features, out_features) for _ in range(n_tasks)])

    def set_n_tasks(self, n_tasks=1):
        """Grow the bank to ``n_tasks`` layers by deep-copying ``m_list[0]``.

        Raises:
            ValueError: if ``n_tasks`` is smaller than the current count.
        """
        if n_tasks < self.n_tasks:
            raise ValueError('Can not decrease the number of tasks in fw module')
        extra = n_tasks - self.n_tasks
        self.n_tasks = n_tasks
        self.m_list.extend(deepcopy(self.m_list[0]) for _ in range(extra))

    def forward(self, x):
        """Apply the layer bank to ``x``.

        - One input and several tasks: every task layer sees that input.
        - Several inputs and one task: the single layer is applied to each.
        - Equal counts: input ``i`` goes through task layer ``i``.

        Raises:
            ValueError: for any other combination of counts.
        """
        inputs = x if isinstance(x, list) else [x]
        count = len(inputs)

        if count == 1 and self.n_tasks > 1:
            return [layer(inputs[0]) for layer in self.m_list]
        if count > 1 and self.n_tasks == 1:
            only = self.m_list[0]
            return [only(item) for item in inputs]
        if count == self.n_tasks:
            return [layer(item) for layer, item in zip(self.m_list, inputs)]
        raise ValueError('Error')

class Conv2d_fw(nn.Module):  # used in MAML to forward input with fast weight
    """A bank of ``nn.Conv2d`` layers, one per task, stored in ``m_list``.

    All constructor arguments except ``n_tasks`` are forwarded unchanged to
    every ``nn.Conv2d``. ``forward`` returns a list of tensors.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, groups=1, dilation=1, n_tasks=1):
        super().__init__()
        self.n_tasks = n_tasks
        self.m_list = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding, dilation=dilation,
                      groups=groups, bias=bias)
            for _ in range(n_tasks)
        ])

    def set_n_tasks(self, n_tasks=1):
        """Add deep copies of ``m_list[0]`` until there are ``n_tasks`` convs.

        Raises:
            ValueError: when asked to shrink the bank.
        """
        if n_tasks < self.n_tasks:
            raise ValueError('Can not decrease the number of tasks in fw module')
        missing = n_tasks - self.n_tasks
        self.n_tasks = n_tasks
        self.m_list.extend(deepcopy(self.m_list[0]) for _ in range(missing))

    def forward(self, x):
        """Route input tensor(s) through the per-task convolutions.

        One input is broadcast to all task convs. Many inputs share a single
        conv when ``n_tasks == 1``. Otherwise inputs and convs are paired by
        position.

        Raises:
            ValueError: if the input count fits none of these cases.
        """
        feats = x if isinstance(x, list) else [x]
        n_feats = len(feats)

        if n_feats == 1 and self.n_tasks > 1:
            return [conv(feats[0]) for conv in self.m_list]
        if n_feats > 1 and self.n_tasks == 1:
            shared_conv = self.m_list[0]
            return [shared_conv(f) for f in feats]
        if n_feats == self.n_tasks:
            return [conv(f) for conv, f in zip(self.m_list, feats)]
        raise ValueError('Error')

class BatchNorm2d_fw(nn.Module):  # used in MAML to forward input with fast weight
    """A bank of ``nn.BatchNorm2d`` layers, one per task, in ``m_list``.

    Constructor arguments other than ``n_tasks`` are passed to each
    ``nn.BatchNorm2d``. ``forward`` returns a list of tensors.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None, n_tasks=1):
        super().__init__()
        self.n_tasks = n_tasks
        self.m_list = nn.ModuleList([
            nn.BatchNorm2d(num_features, eps=eps, momentum=momentum, affine=affine,
                           track_running_stats=track_running_stats,
                           device=device, dtype=dtype)
            for _ in range(n_tasks)
        ])

    def set_n_tasks(self, n_tasks=1):
        """Extend the bank to ``n_tasks`` layers with deep copies of ``m_list[0]``.

        Raises:
            ValueError: if ``n_tasks`` would reduce the bank.
        """
        if n_tasks < self.n_tasks:
            raise ValueError('Can not decrease the number of tasks in fw module')
        to_add = n_tasks - self.n_tasks
        self.n_tasks = n_tasks
        self.m_list.extend(deepcopy(self.m_list[0]) for _ in range(to_add))

    def forward(self, x):
        """Normalise input tensor(s) with the per-task batch-norm layers.

        Broadcast one input to all tasks, share one layer across many
        inputs, or pair inputs with layers one-to-one.

        Raises:
            ValueError: when the number of inputs matches none of these.
        """
        batch = x if isinstance(x, list) else [x]
        size = len(batch)

        if size == 1 and self.n_tasks > 1:
            return [norm(batch[0]) for norm in self.m_list]
        if size > 1 and self.n_tasks == 1:
            lone_norm = self.m_list[0]
            return [lone_norm(b) for b in batch]
        if size == self.n_tasks:
            return [norm(b) for norm, b in zip(self.m_list, batch)]
        raise ValueError('Error')

class ReLU_fw(nn.ReLU):
    """``nn.ReLU`` mapped over a list of tensors; a bare tensor is wrapped first."""

    def __init__(self, inplace=False):
        super().__init__(inplace)

    def forward(self, x):
        """Return a new list with ReLU applied to each tensor in ``x``.

        With ``inplace=True`` the input tensors themselves are modified.
        """
        tensors = x if isinstance(x, list) else [x]
        relu = super(ReLU_fw, self).forward
        return [relu(t) for t in tensors]

class Sigmoid_fw(nn.Sigmoid):
    """``nn.Sigmoid`` mapped over a list of tensors; a bare tensor is wrapped first."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return a list holding the sigmoid of every tensor in ``x``."""
        tensors = x if isinstance(x, list) else [x]
        sigmoid = super(Sigmoid_fw, self).forward
        return [sigmoid(t) for t in tensors]

class MaxPool2d_fw(nn.MaxPool2d):
    """``nn.MaxPool2d`` mapped over a list of tensors.

    With ``return_indices=True``, ``forward`` returns a ``(values, indices)``
    pair of lists. Otherwise it returns a single list of pooled tensors.
    """

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False):
        super().__init__(kernel_size, stride=stride, padding=padding, dilation=dilation,
                         return_indices=return_indices, ceil_mode=ceil_mode)

    def forward(self, x):
        """Pool every tensor in ``x`` (a bare tensor is treated as ``[x]``)."""
        tensors = x if isinstance(x, list) else [x]
        pool = super(MaxPool2d_fw, self).forward

        if not self.return_indices:
            return [pool(t) for t in tensors]

        pairs = [pool(t) for t in tensors]
        values = [val for val, _ in pairs]
        positions = [pos for _, pos in pairs]
        return values, positions

class MaxUnpool2d_fw(nn.MaxUnpool2d):
    """``nn.MaxUnpool2d`` mapped over lists of inputs and pooling indices.

    Both ``x`` and ``indices`` may be a bare tensor or a list of tensors.
    """

    def __init__(self, kernel_size, stride=None, padding=0):
        super(MaxUnpool2d_fw, self).__init__(kernel_size, stride, padding)

    def forward(self, x, indices):
        """Unpool each input with its matching indices.

        Args:
            x: tensor or list of tensors to unpool.
            indices: tensor or list of index tensors from a max-pool. A single
                entry is reused for every input when ``x`` has several.

        Returns:
            List of unpooled tensors, one per input.

        Raises:
            ValueError: if the counts of inputs and indices cannot be paired.
        """
        if not isinstance(x, list):
            x = [x]
        # Wrap a bare index tensor as well. Otherwise len(indices) would be
        # its batch size and indices[i] would slice into that tensor.
        if not isinstance(indices, list):
            indices = [indices]

        unpool = super(MaxUnpool2d_fw, self).forward
        if len(x) == len(indices):
            return [unpool(x_i, idx_i) for x_i, idx_i in zip(x, indices)]
        if len(x) > 1 and len(indices) == 1:
            return [unpool(x_i, indices[0]) for x_i in x]
        raise ValueError('Error!')

class AdaptiveAvgPool2d_fw(nn.AdaptiveAvgPool2d):
    """``nn.AdaptiveAvgPool2d`` mapped over a list of tensors."""

    def __init__(self, output_size):
        super().__init__(output_size)

    def forward(self, x):
        """Return a list with every tensor in ``x`` average-pooled to ``output_size``."""
        tensors = x if isinstance(x, list) else [x]
        avg_pool = super(AdaptiveAvgPool2d_fw, self).forward
        return [avg_pool(t) for t in tensors]

def dot_fw(x, y):
    """Multiply (``*``) two operands that may each be a value or a list.

    A single-element side is broadcast against a longer one. Equal-length
    lists are multiplied pairwise.

    Returns:
        List of products.

    Raises:
        ValueError: if the lengths can be neither broadcast nor paired.
    """
    lhs = x if isinstance(x, list) else [x]
    rhs = y if isinstance(y, list) else [y]

    if len(lhs) == 1 and len(rhs) > 1:
        head = lhs[0]
        return [head * b for b in rhs]
    if len(lhs) > 1 and len(rhs) == 1:
        tail = rhs[0]
        return [a * tail for a in lhs]
    if len(lhs) == len(rhs):
        return [a * b for a, b in zip(lhs, rhs)]
    raise ValueError('Error')

def cat_fw(x, y, dim=0):
    """Concatenate ``x`` and ``y`` along ``dim``, each a tensor or list of tensors.

    A single tensor on one side is paired with every tensor on the other.
    Equal-length lists are concatenated position by position. ``x`` always
    comes first in each result.

    Raises:
        ValueError: if the list lengths can be neither broadcast nor paired.
    """
    lefts = x if isinstance(x, list) else [x]
    rights = y if isinstance(y, list) else [y]

    if len(lefts) == 1 and len(rights) > 1:
        left = lefts[0]
        return [torch.cat((left, r), dim=dim) for r in rights]
    if len(lefts) > 1 and len(rights) == 1:
        right = rights[0]
        return [torch.cat((l_t, right), dim=dim) for l_t in lefts]
    if len(lefts) == len(rights):
        return [torch.cat((l_t, r), dim=dim) for l_t, r in zip(lefts, rights)]
    raise ValueError('Error')

def interpolate_fw(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
    """Apply ``F.interpolate`` to a tensor or to each tensor of a list.

    All keyword arguments are forwarded unchanged to ``F.interpolate``.

    Returns:
        List of resized tensors.
    """
    tensors = input if isinstance(input, list) else [input]
    return [
        F.interpolate(
            t,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
            recompute_scale_factor=recompute_scale_factor,
            antialias=antialias,
        )
        for t in tensors
    ]

def is_father_str(str, str_list):
    """Return True if any entry of ``str_list`` is a substring of ``str``.

    The parameter keeps the name ``str`` (which shadows the builtin) because
    callers may pass it by keyword.
    """
    return any(candidate in str for candidate in str_list)

class SegNet_fw(nn.Module):
    """SegNet encoder-decoder built from ``*_fw`` layers, with optional task attention.

    Every conv / batch-norm layer is a ``*_fw`` wrapper that holds one or more
    copies of the underlying layer in ``m_list``. Layers named in the selected
    task branches are expanded to ``n_tasks`` copies (``set_n_tasks``);
    everything else stays shared. Layers under the names returned by
    ``task_dependent_modules_names()`` (attention modules and prediction
    heads) are always excluded from the shared-parameter table.

    Args:
        n_tasks: number of tasks; each branched layer gets this many copies.
        branched: how branched layers are chosen:
            ``'empty'``: no branching;
            ``'ablation'``: names loaded from the JSON file ``ablation_file``;
            ``'branch'``: the first ``topK`` keys of the JSON object in
            ``branched_files``, or of a built-in list when it is None.
        branched_files: JSON path used when ``branched == 'branch'``.
        attention: build the attention path (encoder/decoder attention).
        topK: number of layer names kept when ``branched == 'branch'``.
        T: two flags. ``T[0]`` builds ``pred_task1`` (``class_nb`` output
            channels) and ``T[1]`` builds ``pred_task2`` (1 output channel).
            ``forward`` uses both heads.
        ablation_file: JSON path used when ``branched == 'ablation'``.

    Raises:
        ValueError: if ``branched`` is not one of the values above.
    """

    def __init__(self, n_tasks=2, branched='empty', branched_files=None, attention=True, topK=5, T=(True, True), ablation_file=None):
        super(SegNet_fw, self).__init__()
        # Channel widths of the five encoder/decoder stages.
        filters = [64, 128, 256, 512, 512]
        self.class_nb = 7  # output channels of pred_task1
        self.n_tasks = n_tasks
        self.attention = attention

        # define encoder decoder layers
        self.encoder_block = nn.ModuleList([self.conv_layer([3, filters[0]])])
        self.decoder_block = nn.ModuleList([self.conv_layer([filters[0], filters[0]])])
        for i in range(4):
            self.encoder_block.append(self.conv_layer([filters[i], filters[i + 1]]))
            self.decoder_block.append(self.conv_layer([filters[i + 1], filters[i]]))

        # define convolution layer
        self.conv_block_enc = nn.ModuleList([self.conv_layer([filters[0], filters[0]])])
        self.conv_block_dec = nn.ModuleList([self.conv_layer([filters[0], filters[0]])])
        for i in range(4):
            if i == 0:
                self.conv_block_enc.append(self.conv_layer([filters[i + 1], filters[i + 1]]))
                self.conv_block_dec.append(self.conv_layer([filters[i], filters[i]]))
            else:
                self.conv_block_enc.append(nn.Sequential(self.conv_layer([filters[i + 1], filters[i + 1]]),
                                                         self.conv_layer([filters[i + 1], filters[i + 1]])))
                self.conv_block_dec.append(nn.Sequential(self.conv_layer([filters[i], filters[i]]),
                                                         self.conv_layer([filters[i], filters[i]])))

        if attention:
            # define task attention layers (built with n_tasks copies up front)
            self.encoder_att = nn.ModuleList([self.att_layer([filters[0], filters[0], filters[0]], n_tasks=n_tasks)])
            self.decoder_att = nn.ModuleList([self.att_layer([2 * filters[0], filters[0], filters[0]], n_tasks=n_tasks)])

            for i in range(4):
                self.encoder_att.append(self.att_layer([2 * filters[i + 1], filters[i + 1], filters[i + 1]], n_tasks=n_tasks))
                self.decoder_att.append(self.att_layer([filters[i + 1] + filters[i], filters[i], filters[i]], n_tasks=n_tasks))

            self.encoder_block_att = nn.ModuleList([self.conv_layer([filters[0], filters[1]])])
            self.decoder_block_att = nn.ModuleList([self.conv_layer([filters[0], filters[0]])])
            for i in range(4):
                if i < 3:
                    self.encoder_block_att.append(self.conv_layer([filters[i + 1], filters[i + 2]]))
                    self.decoder_block_att.append(self.conv_layer([filters[i + 1], filters[i]]))
                else:
                    self.encoder_block_att.append(self.conv_layer([filters[i + 1], filters[i + 1]]))
                    self.decoder_block_att.append(self.conv_layer([filters[i + 1], filters[i + 1]]))

        # Prediction heads. forward() calls both, so a model built with
        # either flag False can be inspected (e.g. model_size()) but not run.
        if T[0]:
            self.pred_task1 = self.conv_layer([filters[0], self.class_nb], pred=True)

        if T[1]:
            self.pred_task2 = self.conv_layer([filters[0], 1], pred=True)

        # self.pred_task3 = self.conv_layer([filters[0], 3], pred=True)

        # define pooling and unpooling functions
        self.max_pool = MaxPool2d_fw(kernel_size=2, stride=2)

        self.down_sampling = MaxPool2d_fw(kernel_size=2, stride=2, return_indices=True)
        self.up_sampling = MaxUnpool2d_fw(kernel_size=2, stride=2)

        # Three entries although only two heads are built (pred_task3 above
        # is commented out); returned unchanged by forward().
        self.logsigma = nn.Parameter(torch.FloatTensor([-0.5, -0.5, -0.5]))

        self._shared_parameters = OrderedDict()

        # Collect the task-dependent layers: every module owning a parameter
        # whose full name contains one of task_dependent_modules_names().
        # The stored name is the *_fw wrapper's path, i.e. the module path
        # cut just before '.m_list'.
        td_layers = []
        td_names = self.task_dependent_modules_names()
        for idx, m in self.named_modules():
            members = m._parameters.items()
            memo = set()
            for k, v in members:
                if v is None or v in memo:
                    continue
                memo.add(v)
                name = idx + ('.' if idx else '') + k

                if is_father_str(name, td_names):
                    td_layers.append(idx[:idx.index('.m_list')])

        self.td_layers = list(set(td_layers))

        # insert task branches
        if branched == 'empty':
            task_branches = []
        elif branched == 'ablation':
            with open(ablation_file, "r") as fp:
                task_branches = json.load(fp)
        elif branched == 'branch':
            if branched_files is not None:
                # The JSON object's keys are taken as layer names, in file order.
                with open(branched_files, "r") as fp:
                    task_branches = json.load(fp)
                    task_branches = list(task_branches.keys())
                    task_branches = task_branches[:topK]

            else:
                # lw_cos_flood_0d0_ep200
                # task_branches = ['encoder_block_att.0.conv1', 'decoder_block_att.0.bn1', 'encoder_block_att.1.bn1', 'decoder_block_att.2.bn1', 'encoder_block_att.0.bn1',
                #                  'encoder_block_att.3.bn1', 'decoder_block_att.4.bn1', 'encoder_block_att.1.conv1', 'decoder_block_att.3.bn1', 'conv_block_dec.0.bn1',
                #                  'encoder_block_att.2.bn1', 'conv_block_dec.0.conv1', 'decoder_block_att.1.bn1', 'encoder_block_att.3.conv1', 'encoder_block_att.2.conv1',
                #                  'encoder_block_att.4.bn1', 'conv_block_dec.1.bn1', 'decoder_block.0.bn1', 'decoder_block_att.0.conv1', 'decoder_block_att.3.conv1',
                #                  'decoder_block_att.1.conv1', 'decoder_block_att.4.conv1', 'decoder_block_att.2.conv1', 'encoder_block_att.4.conv1', 'decoder_block.1.bn1']

                # S=-0.05 epoch=40 Adam lr 5e-5, approach=joint-train
                task_branches = [
                    'decoder_block.1.bn1',
                    'conv_block_dec.1.bn1',
                    'encoder_block.0.conv1',
                    'conv_block_dec.2.1.bn1',
                    'conv_block_dec.2.0.bn1',
                    'conv_block_dec.1.conv1',
                    'decoder_block.2.bn1',
                    'decoder_block.1.conv1',
                    'encoder_block.0.bn1',
                    'conv_block_enc.0.conv1',
                    'decoder_block.0.bn1',
                    'conv_block_enc.0.bn1',
                    'conv_block_dec.3.1.bn1',
                    'encoder_block.1.bn1',
                    'conv_block_enc.1.bn1',
                    'conv_block_dec.3.0.bn1',
                    'encoder_block.1.conv1',
                    'conv_block_dec.2.1.conv1',
                    'conv_block_dec.2.0.conv1',
                    'conv_block_enc.1.conv1',
                    'encoder_block.2.bn1',
                    'conv_block_enc.2.1.bn1',
                    'conv_block_enc.2.0.bn1',
                    'decoder_block.3.bn1',
                    'decoder_block.2.conv1',
                    'decoder_block.0.conv1',
                    'encoder_block.3.bn1',
                    'encoder_block.2.conv1',
                    'conv_block_dec.4.1.bn1',
                    'conv_block_enc.3.0.bn1',
                    'conv_block_enc.3.1.bn1',
                    'encoder_block.4.bn1',
                    'conv_block_enc.2.0.conv1',
                    'conv_block_enc.4.0.bn1',
                    'conv_block_dec.4.0.bn1',
                    'conv_block_enc.2.1.conv1',
                    'conv_block_dec.0.bn1',
                    'conv_block_enc.4.1.bn1',
                    'encoder_block.3.conv1',
                    'decoder_block.4.bn1',
                    'conv_block_dec.3.1.conv1',
                    'conv_block_enc.3.0.conv1',
                    'conv_block_enc.3.1.conv1',
                    'conv_block_dec.3.0.conv1',
                    'encoder_block.4.conv1',
                    'conv_block_enc.4.0.conv1',
                    'decoder_block.3.conv1',
                    'conv_block_enc.4.1.conv1',
                    'conv_block_dec.0.conv1',
                    'conv_block_dec.4.1.conv1',
                    'decoder_block.4.conv1',
                    'conv_block_dec.4.0.conv1'
                ]

                task_branches = task_branches[:topK]
        else:
            raise ValueError('Not Defined!')

        self.turn(task_branches=task_branches)

        # Initialise after branching so every deep-copied branch gets its own
        # independent initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def model_size(self, unit='MB'):
        """Return the total size of all parameters in ``unit`` ('MB' or 'B').

        Raises:
            ValueError: for any other unit.
        """
        param_size = 0
        for param in self.parameters():
            param_size += param.nelement() * param.element_size()

        if unit == 'MB':
            size_all_out = param_size / 1024 ** 2
        elif unit == 'B':
            size_all_out = param_size
        else:
            raise ValueError(f'Error: Do not support unit {unit}')

        return size_all_out

    def turn(self, task_branches=None):
        """Branch the requested layers and rebuild the shared-parameter table.

        Args:
            task_branches: module names (as produced by ``named_modules()``)
                whose ``*_fw`` wrapper should hold ``n_tasks`` copies. Any
                iterable of names, or a single name, is accepted. ``None``
                (the default) branches nothing.

        Afterwards ``self._shared_parameters`` maps each parameter name to its
        parameter, for every parameter that belongs neither to a branched
        layer nor to a task-dependent layer (``self.td_layers``). ``logsigma``
        is excluded. Names drop the ``.m_list.<k>`` part, e.g.
        ``encoder_block.0.conv1.weight``.
        """
        # None instead of a shared mutable ``[]`` default; a set makes the
        # membership tests below O(1) and accepts any iterable.
        if task_branches is None:
            requested = set()
        elif isinstance(task_branches, str):
            requested = {task_branches}
        else:
            requested = set(task_branches)

        self._shared_parameters.clear()

        if self.n_tasks > 1:
            # Collect first: set_n_tasks adds submodules, which should not
            # happen while the named_modules() generator is being walked.
            targets = [m for idx, m in self.named_modules() if idx in requested]
            for m in targets:
                m.set_n_tasks(n_tasks=self.n_tasks)

        excluded = requested.union(self.td_layers)

        # obtain the shared parameters
        for idx, m in self.named_modules():
            # Attribute parameters inside '<layer>.m_list.<k>' to '<layer>'.
            if '.m_list' in idx:
                idx = idx[:idx.index('.m_list')]

            if idx not in excluded:
                members = m._parameters.items()
                memo = set()
                for k, v in members:
                    if v is None or v in memo:
                        continue
                    memo.add(v)
                    name = idx + ('.' if idx else '') + k
                    self._shared_parameters[name] = v

        # delete logsigma
        del self._shared_parameters['logsigma']

    def shared_parameters(self):
        """Return the ``OrderedDict`` of shared parameters built by ``turn``."""
        return self._shared_parameters

    # def shared_modules(self):
    #     return [self.encoder_block, self.decoder_block,
    #             self.conv_block_enc, self.conv_block_dec,
    #             # self.encoder_att, self.decoder_att,
    #             self.encoder_block_att, self.decoder_block_att,
    #             self.down_sampling, self.up_sampling]

    def shared_modules_name(self):
        """Return the attribute names of the (potentially) shared sub-networks."""
        return ['encoder_block', 'decoder_block',
                'conv_block_enc', 'conv_block_dec',
                'encoder_block_att', 'decoder_block_att',
                'down_sampling', 'up_sampling']

    def task_dependent_modules_names(self):
        """Return name fragments that mark a layer as task-dependent."""
        return ['encoder_att', 'decoder_att',
                'pred_task1', 'pred_task2']

    def zero_grad_shared_modules(self):
        """Zero the ``.grad`` of every shared parameter in place.

        A grad that carries a ``grad_fn`` is detached first; otherwise its
        ``requires_grad`` is switched off before zeroing.
        """
        for name, p in self.shared_parameters().items():
            if p.grad is not None:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()

    def add_branch_layers(self, layer_list):
        """Expand each module named in ``layer_list`` to ``n_tasks`` copies.

        Unlike ``turn``, this does not rebuild the shared-parameter table.
        """
        # Collect first so modules are not added during named_modules() iteration.
        targets = [m for idx, m in self.named_modules() if idx in layer_list]
        for m in targets:
            m.set_n_tasks(n_tasks=self.n_tasks)

    def conv_layer(self, channel, pred=False):
        """Build a conv block from ``channel = [in_channels, out_channels]``.

        With ``pred=False``: 3x3 conv -> batch norm -> in-place ReLU.
        With ``pred=True``: 3x3 conv keeping ``in_channels``, then a 1x1 conv to
        ``out_channels``. This form has no normalisation or activation.
        """
        if not pred:
            conv_block = nn.Sequential(
                OrderedDict([
                    ('conv1', Conv2d_fw(in_channels=channel[0], out_channels=channel[1], kernel_size=3, padding=1)),
                    ('bn1', BatchNorm2d_fw(num_features=channel[1])),
                    ('relu1', ReLU_fw(inplace=True))
                ])
            )
        else:
            conv_block = nn.Sequential(
                OrderedDict([
                    ('conv1', Conv2d_fw(in_channels=channel[0], out_channels=channel[0], kernel_size=3, padding=1)),
                    ('conv2', Conv2d_fw(in_channels=channel[0], out_channels=channel[1], kernel_size=1, padding=0))
                ])
            )
        return conv_block

    def att_layer(self, channel, n_tasks=1):
        """Build an attention-mask block from ``channel = [in, mid, out]``.

        The block is 1x1 conv -> BN -> ReLU -> 1x1 conv -> BN -> sigmoid. Each
        conv/BN holds ``n_tasks`` copies.
        """
        att_block = nn.Sequential(
            OrderedDict([
                ('conv1', Conv2d_fw(in_channels=channel[0], out_channels=channel[1], kernel_size=1, padding=0, n_tasks=n_tasks)),
                ('bn1', BatchNorm2d_fw(channel[1], n_tasks=n_tasks)),
                ('relu1', ReLU_fw(inplace=True)),
                ('conv2', Conv2d_fw(in_channels=channel[1], out_channels=channel[2], kernel_size=1, padding=0, n_tasks=n_tasks)),
                ('bn2', BatchNorm2d_fw(channel[2], n_tasks=n_tasks)),
                ('Sigmoid2', Sigmoid_fw())
            ])
        )
        return att_block

    def forward(self, x):
        """Run the network and return ``([t1_pred, t2_pred], self.logsigma)``.

        ``x`` feeds a conv with 3 input channels; presumably an image batch of
        shape (N, 3, H, W), with H and W divisible by 32 so the five
        pool/unpool stages line up. TODO: confirm.

        Every intermediate value is a list with one entry per active branch.
        When the final feature list has ``n_tasks`` entries, task 1 uses entry
        0 and task 2 uses entry 1. Only the first two entries are read, so
        this assumes ``n_tasks == 2``. When it has one entry, both heads share
        it.

        Raises:
            ValueError: if the final feature list has any other length.
        """
        # g_encoder[i] / g_decoder[i] each hold two stage outputs.
        g_encoder, g_decoder, g_maxpool, g_upsampl, indices = ([0] * 5 for _ in range(5))
        for i in range(5):
            g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))

        # define attention list for tasks
        # atten_encoder[j] / atten_decoder[j] each hold three stage outputs.
        atten_encoder, atten_decoder = ([0] * 5 for _ in range(2))
        for j in range(5):
            atten_encoder[j], atten_decoder[j] = ([0] * 3 for _ in range(2))

        # define global shared network
        # Encoder: block -> conv block -> max-pool (keeping indices for unpooling).
        for i in range(5):
            if i == 0:
                g_encoder[i][0] = self.encoder_block[i](x)
                g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
                g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
            else:
                g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1])
                g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
                g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])

        # Decoder: unpool with the mirrored encoder indices -> block -> conv block.
        for i in range(5):
            if i == 0:
                g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1])
                g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
                g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
            else:
                g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
                g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
                g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])

        if self.attention:
            # Encoder attention: mask from the stage's first feature (concatenated
            # with the previous attention output after stage 0), multiplied
            # element-wise into the stage's second feature, then conv + pool.
            for j in range(5):
                if j == 0:
                    atten_encoder[j][0] = self.encoder_att[j](g_encoder[j][0])
                    atten_encoder[j][1] = dot_fw(atten_encoder[j][0], g_encoder[j][1])
                    atten_encoder[j][2] = self.encoder_block_att[j](atten_encoder[j][1])
                    atten_encoder[j][2] = self.max_pool(atten_encoder[j][2])
                else:
                    atten_encoder[j][0] = self.encoder_att[j](
                        cat_fw(g_encoder[j][0], atten_encoder[j - 1][2], dim=1))
                    atten_encoder[j][1] = dot_fw(atten_encoder[j][0], g_encoder[j][1])
                    atten_encoder[j][2] = self.encoder_block_att[j](atten_encoder[j][1])
                    atten_encoder[j][2] = self.max_pool(atten_encoder[j][2])

            # Decoder attention: bilinear x2 upsample -> conv -> mask from
            # [unpooled feature, upsampled attention] -> multiply into decoder feature.
            for j in range(5):
                if j == 0:
                    atten_decoder[j][0] = interpolate_fw(atten_encoder[-1][-1], scale_factor=2, mode='bilinear',
                                                         align_corners=True)
                    atten_decoder[j][0] = self.decoder_block_att[-j - 1](atten_decoder[j][0])
                    atten_decoder[j][1] = self.decoder_att[-j - 1](
                        cat_fw(g_upsampl[j], atten_decoder[j][0], dim=1))
                    atten_decoder[j][2] = dot_fw(atten_decoder[j][1], g_decoder[j][-1])

                else:
                    atten_decoder[j][0] = interpolate_fw(atten_decoder[j - 1][2], scale_factor=2, mode='bilinear',
                                                         align_corners=True)
                    atten_decoder[j][0] = self.decoder_block_att[-j - 1](atten_decoder[j][0])
                    atten_decoder[j][1] = self.decoder_att[-j - 1](
                        cat_fw(g_upsampl[j], atten_decoder[j][0], dim=1))
                    atten_decoder[j][2] = dot_fw(atten_decoder[j][1], g_decoder[j][-1])

            if len(atten_decoder[-1][-1]) == self.n_tasks:
                t1_pred = self.pred_task1(atten_decoder[-1][-1][0])[0]
                t2_pred = self.pred_task2(atten_decoder[-1][-1][1])[0]
            elif len(atten_decoder[-1][-1]) == 1:
                t1_pred = self.pred_task1(atten_decoder[-1][-1][0])[0]
                t2_pred = self.pred_task2(atten_decoder[-1][-1][0])[0]
            else:
                raise ValueError('Error')

        else:
            if len(g_decoder[-1][-1]) == self.n_tasks:
                t1_pred = self.pred_task1(g_decoder[-1][-1][0])[0]
                t2_pred = self.pred_task2(g_decoder[-1][-1][1])[0]
            elif len(g_decoder[-1][-1]) == 1:
                t1_pred = self.pred_task1(g_decoder[-1][-1][0])[0]
                t2_pred = self.pred_task2(g_decoder[-1][-1][0])[0]
            else:
                raise ValueError('Error')

        return [t1_pred, t2_pred], self.logsigma

def parameter_sz(p_list, unit='B'):
    """Return the storage size of the tensors in ``p_list``.

    Args:
        p_list: iterable of tensors/parameters.
        unit: ``'B'`` for bytes or ``'MB'`` for mebibytes (bytes / 1024**2).

    Raises:
        ValueError: for any other unit.
    """
    total_bytes = sum(p.nelement() * p.element_size() for p in p_list)

    if unit == 'B':
        return total_bytes
    if unit == 'MB':
        return total_bytes / 1024 ** 2
    raise ValueError(f'Error: Do not support unit {unit}')


if __name__ == '__main__':
    # Script: draw random shared layers to branch and save the selections as
    # JSON under ./saved/. It writes a fixed count (topK) and a size-matched
    # selection.
    import os

    import numpy as np

    # net = SegNet_fw(branched='empty', attention=False, T=[False, False])
    # normal_model_size = net.model_size()
    # print(f'noraml_na: {normal_model_size:.2f}')
    # net = SegNet_fw(branched='branch', branched_files='./saved/0nothing_Adam_40_fd-0.0_0_40.json', attention=False, topK=35, T=[False, False])
    # model_size = net.model_size()
    # print(f'Recon_na: {model_size:.2f}')
    #
    # net_t1 = SegNet_fw(branched='empty', attention=False, T=[True, False])
    # net_t2 = SegNet_fw(branched='empty', attention=False, T=[False, True])
    #
    # m = net_t1.model_size() + net_t2.model_size()
    # print(f'Single: {m:.2f}')
    #
    # n_tasks = 2
    # roto_param = nn.Parameter(torch.eye(1024))
    # roto_param_size = roto_param.nelement() * roto_param.element_size() * n_tasks / 1024 ** 2
    #
    # print(f'Roto: {roto_param_size + normal_model_size:.2f}')

    def get_layer_dict(m):
        """Group ``m.shared_parameters()`` by layer name.

        ``.weight`` / ``.bias`` is removed from each parameter name, so a
        layer's weight and bias land in one list. Keys keep first-seen order.
        """
        layer_dict = {}
        for name, param in m.shared_parameters().items():
            if '.weight' in name:
                name = name.replace('.weight', '')
            elif '.bias' in name:
                name = name.replace('.bias', '')
            layer_dict.setdefault(name, []).append(param)
        return layer_dict

    seed = 0
    np.random.seed(seed)
    random.seed(seed)

    # The JSON files below are written into ./saved; make sure it exists.
    os.makedirs('./saved', exist_ok=True)

    net = SegNet_fw(branched='empty', attention=False)
    layer_dict = get_layer_dict(net)
    layer_names = list(layer_dict.keys())

    # 1) A fixed number (topK) of randomly ordered layers.
    topK = 39
    n_layers = len(layer_names)
    permute = np.random.permutation(n_layers)
    layer_name_permute = [layer_names[p] for p in permute]
    random_layer_names = layer_name_permute[:topK]

    saved_file = f'./saved/seed{seed}_random_select_{topK}Layers.json'

    with open(saved_file, "w") as fp:
        json.dump(random_layer_names, fp)

    # 2) Random layers whose combined size matches the extra bytes of the
    # branched net (same topK) over the unbranched one.
    branch_net = SegNet_fw(branched='branch', branched_files='./saved/0nothing_Adam_40_fd-0.0_0_40.json', attention=False, topK=topK)
    branch_net_sz = branch_net.model_size(unit='B')
    sz = net.model_size(unit='B')
    sz_bound = branch_net_sz - sz
    param_sz = [parameter_sz(layer_dict[name]) for name in layer_name_permute]

    # Take layers in permuted order until their cumulative size first exceeds
    # the budget. If it never does, take every layer rather than none.
    n_select_layers = len(param_sz)
    p_sz_total = 0
    for i, p_sz in enumerate(param_sz):
        p_sz_total += p_sz
        if p_sz_total > sz_bound:
            n_select_layers = i + 1
            break

    random_layer_names = layer_name_permute[:n_select_layers]
    saved_file = f'./saved/seed{seed}_random_select_comparable_size.json'

    with open(saved_file, "w") as fp:
        json.dump(random_layer_names, fp)

    # net = SegNet_fw(branched='empty', attention=False)
    # x = torch.rand(2, 3, 128, 256)
    # output = net(x)
