import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from rotograd import RotoGrad
from collections import OrderedDict

# task_branches = []


# def is_branches():

def conv_layer(channel, pred=False):
    """Build a small convolutional block as a named ``nn.Sequential``.

    Args:
        channel: indexable pair; ``channel[0]`` is the number of input
            channels and ``channel[1]`` the number of output channels.
        pred: selects which block is built.
            * ``False`` (default): 3x3 conv (padding 1) -> BatchNorm2d ->
              in-place ReLU, registered as ``conv1``, ``bn1``, ``relu1``.
            * ``True``: 3x3 conv (padding 1) keeping ``channel[0]`` channels,
              followed by a 1x1 conv mapping to ``channel[1]`` channels,
              registered as ``conv1``, ``conv2``. No normalisation or
              activation is applied in this variant.

    Returns:
        The assembled ``nn.Sequential``.
    """
    in_ch, out_ch = channel[0], channel[1]

    if pred:
        # Prediction variant: two bare convolutions.
        named_layers = [
            ('conv1', nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=3, padding=1)),
            ('conv2', nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, padding=0)),
        ]
    else:
        # Feature variant: conv + batch norm + ReLU.
        named_layers = [
            ('conv1', nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1)),
            ('bn1', nn.BatchNorm2d(num_features=out_ch)),
            ('relu1', nn.ReLU(inplace=True)),
        ]

    return nn.Sequential(OrderedDict(named_layers))

class norm_layer(nn.Module):
    """Divide the input by its p-norm.

    Args:
        p: order of the norm, forwarded to ``torch.norm``.
        dim: dimension(s) to reduce over. ``None`` normalises by the norm of
            the whole tensor.
        keepdim: forwarded to ``torch.norm`` when ``dim`` is ``None``. When
            ``dim`` is given the reduced axis is always kept (see
            ``forward``); the argument is retained for interface
            compatibility.
    """

    def __init__(self, p=2, dim=None, keepdim=False):
        super(norm_layer, self).__init__()
        self.p = p
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, x):
        """Return ``x`` scaled to unit p-norm along ``self.dim``.

        The output always has the same shape as ``x``. No epsilon is added,
        so an all-zero slice yields NaN/inf.
        """
        # When reducing over specific dims the norm must keep those axes,
        # otherwise ``x / norm`` either fails to broadcast or silently
        # divides along the wrong axis.
        keepdim = self.keepdim if self.dim is None else True
        norm = torch.norm(x, p=self.p, dim=self.dim, keepdim=keepdim)
        return x / norm

class SegNet(nn.Module):
    """Encoder/decoder backbone with max-pool / max-unpool stages.

    The network has five encoder stages and five decoder stages. Each stage
    is a ``conv_layer`` block followed by one (first two encoder stages /
    last two decoder stages) or two extra ``conv_layer`` blocks. Encoder
    stages end with a 2x2 max-pool whose indices are reused by the matching
    decoder stage's max-unpool.

    Per-task attention modules (``encoder_att``, ``decoder_att``,
    ``encoder_block_att``, ``decoder_block_att``) are constructed and
    registered, but ``forward`` does not currently use them: it returns only
    the shared decoder features.
    """

    def __init__(self):
        super(SegNet, self).__init__()
        # initialise network parameters
        filters = [64, 128, 256, 512, 512]  # channel width of each stage
        self.class_nb = 13

        # define encoder decoder layers
        self.encoder_block = nn.ModuleList([conv_layer([3, filters[0]])])
        self.decoder_block = nn.ModuleList([conv_layer([filters[0], filters[0]])])
        for i in range(4):
            self.encoder_block.append(conv_layer([filters[i], filters[i + 1]]))
            self.decoder_block.append(conv_layer([filters[i + 1], filters[i]]))

        # define convolution layer
        self.conv_block_enc = nn.ModuleList([conv_layer([filters[0], filters[0]])])
        self.conv_block_dec = nn.ModuleList([conv_layer([filters[0], filters[0]])])
        for i in range(4):
            if i == 0:
                self.conv_block_enc.append(conv_layer([filters[i + 1], filters[i + 1]]))
                self.conv_block_dec.append(conv_layer([filters[i], filters[i]]))
            else:
                self.conv_block_enc.append(nn.Sequential(conv_layer([filters[i + 1], filters[i + 1]]),
                                                         conv_layer([filters[i + 1], filters[i + 1]])))
                self.conv_block_dec.append(nn.Sequential(conv_layer([filters[i], filters[i]]),
                                                         conv_layer([filters[i], filters[i]])))

        # define task attention layers (one inner ModuleList per task, three tasks)
        self.encoder_att = nn.ModuleList([nn.ModuleList([self.att_layer([filters[0], filters[0], filters[0]])])])
        self.decoder_att = nn.ModuleList([nn.ModuleList([self.att_layer([2 * filters[0], filters[0], filters[0]])])])
        self.encoder_block_att = nn.ModuleList([conv_layer([filters[0], filters[1]])])
        self.decoder_block_att = nn.ModuleList([conv_layer([filters[0], filters[0]])])

        for j in range(3):
            if j < 2:
                # Task 0's list already exists; create the lists for tasks 1 and 2.
                self.encoder_att.append(nn.ModuleList([self.att_layer([filters[0], filters[0], filters[0]])]))
                self.decoder_att.append(nn.ModuleList([self.att_layer([2 * filters[0], filters[0], filters[0]])]))
            for i in range(4):
                self.encoder_att[j].append(self.att_layer([2 * filters[i + 1], filters[i + 1], filters[i + 1]]))
                self.decoder_att[j].append(self.att_layer([filters[i + 1] + filters[i], filters[i], filters[i]]))

        for i in range(4):
            if i < 3:
                self.encoder_block_att.append(conv_layer([filters[i + 1], filters[i + 2]]))
                self.decoder_block_att.append(conv_layer([filters[i + 1], filters[i]]))
            else:
                self.encoder_block_att.append(conv_layer([filters[i + 1], filters[i + 1]]))
                self.decoder_block_att.append(conv_layer([filters[i + 1], filters[i + 1]]))

        # define pooling and unpooling functions
        self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)

        # Xavier-normal weights / zero bias for conv and linear layers,
        # identity affine parameters for batch norm.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def shared_modules(self):
        """Return the modules treated as shared across tasks.

        The order matches ``shared_modules_name``.
        """
        return [self.encoder_block, self.decoder_block,
                self.conv_block_enc, self.conv_block_dec,
                self.encoder_block_att, self.decoder_block_att,
                self.down_sampling, self.up_sampling]

    def shared_modules_name(self):
        """Return the attribute names of the modules in ``shared_modules``."""
        return ['encoder_block', 'decoder_block',
                'conv_block_enc', 'conv_block_dec',
                'encoder_block_att', 'decoder_block_att',
                'down_sampling', 'up_sampling']

    def task_dependent_modules(self):
        """Return the task-specific modules held by this instance.

        Always contains ``encoder_att`` and ``decoder_att``. The prediction
        heads ``pred_task1``/``pred_task2``/``pred_task3`` are not created by
        ``__init__``, so each is appended only if it has been attached to the
        instance; otherwise it is skipped instead of raising
        ``AttributeError``.
        """
        modules = [self.encoder_att, self.decoder_att]
        for head_name in ('pred_task1', 'pred_task2', 'pred_task3'):
            head = getattr(self, head_name, None)
            if head is not None:
                modules.append(head)
        return modules

    def zero_grad_shared_modules(self):
        """Call ``zero_grad()`` on every module in ``shared_modules``."""
        for mm in self.shared_modules():
            mm.zero_grad()

    def att_layer(self, channel):
        """Build an attention block: 1x1 conv -> BN -> ReLU -> 1x1 conv -> BN -> Sigmoid.

        Args:
            channel: indexable triple ``(in, hidden, out)`` of channel counts.

        Returns:
            The assembled ``nn.Sequential`` with submodules named ``conv1``,
            ``bn1``, ``relu1``, ``conv2``, ``bn2``, ``Sigmoid2``.
        """
        att_block = nn.Sequential(
            OrderedDict([
                ('conv1', nn.Conv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=1, padding=0)),
                ('bn1', nn.BatchNorm2d(channel[1])),
                ('relu1', nn.ReLU(inplace=True)),
                ('conv2', nn.Conv2d(in_channels=channel[1], out_channels=channel[2], kernel_size=1, padding=0)),
                ('bn2', nn.BatchNorm2d(channel[2])),
                ('Sigmoid2', nn.Sigmoid())
            ])
        )
        return att_block

    def forward(self, x):
        """Run the shared encoder/decoder and return the last decoder features.

        Args:
            x: input tensor; the first encoder conv expects 3 channels.
                NOTE(review): presumably H and W must be divisible by 32 so
                the five pool/unpool stages line up -- confirm with callers.

        Returns:
            Output of the final decoder stage (``conv_block_dec[0]``).
        """
        num_stages = 5

        # Shared encoder: conv block(s) then max-pool, remembering the pool
        # indices so the decoder can unpool to the same locations.
        feat = x
        pool_indices = []
        for i in range(num_stages):
            feat = self.encoder_block[i](feat)
            feat = self.conv_block_enc[i](feat)
            feat, idx = self.down_sampling(feat)
            pool_indices.append(idx)

        # Shared decoder: walks the stages (and pool indices) in reverse.
        for i in range(num_stages):
            feat = self.up_sampling(feat, pool_indices[-i - 1])
            feat = self.decoder_block[-i - 1](feat)
            feat = self.conv_block_dec[-i - 1](feat)

        # NOTE: the per-task attention path and the pred_task heads are
        # intentionally not applied here; only shared features are returned.
        return feat

def SegNet_RotoGrad(latent_size=1024):
    """Wrap a ``SegNet`` backbone and three prediction heads in ``RotoGrad``.

    The heads all consume the 64-channel backbone output and are built with
    ``conv_layer(..., pred=True)``:
        * head 1 outputs 13 channels,
        * head 2 outputs 1 channel,
        * head 3 outputs 3 channels, L2-normalised over dim 1.
    NOTE(review): these presumably correspond to semantic segmentation,
    depth and surface normals -- confirm against the training script.

    Args:
        latent_size: forwarded unchanged to ``RotoGrad``.

    Returns:
        The constructed ``RotoGrad`` module.
    """
    num_classes = 13
    widths = [64, 128, 256, 512, 512]
    head_in = widths[0]

    backbone = SegNet()

    # Heads are created after the backbone, in task order.
    class_head = conv_layer([head_in, num_classes], pred=True)
    scalar_head = conv_layer([head_in, 1], pred=True)
    unit_vector_head = nn.Sequential(
        conv_layer([head_in, 3], pred=True),
        norm_layer(p=2, dim=1, keepdim=True),
    )

    return RotoGrad(
        backbone=backbone,
        heads=[class_head, scalar_head, unit_vector_head],
        latent_size=latent_size,
    )
