import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import resources
import json
import modules_fw_v2 as fw
from collections import OrderedDict
from pathlib import Path

def read_json(fname):
    """Load a JSON file, keeping the key order of every object.

    Args:
        fname: path to the JSON file (``str`` or ``Path``).

    Returns:
        The decoded document; every JSON object becomes an ``OrderedDict``.
    """
    fname = Path(fname)
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with fname.open('rt', encoding='utf-8') as handle:
        return json.load(handle, object_hook=OrderedDict)


def write_json(content, fname):
    """Serialize ``content`` to ``fname`` as 4-space-indented JSON.

    Keys are written in their existing iteration order (no sorting).
    """
    with Path(fname).open('wt') as out_file:
        json.dump(content, out_file, indent=4, sort_keys=False)

class BranchedLayer(nn.Module):
    """Route each input branch through its own copies of one operation.

    Args:
        operation: module that is deep-copied once per output branch.
        mapping: dict from input-branch key to the list of output-branch
            indices fed by that input; each output index gets its own copy,
            stored in ``self.path`` under ``str(index)``.
    """

    def __init__(self, operation, mapping):
        super().__init__()
        self.mapping = mapping
        self.path = nn.ModuleDict()
        for targets in mapping.values():
            for target in targets:
                self.path[str(target)] = copy.deepcopy(operation)

    def forward(self, x):
        """Apply the per-output copy to its source input.

        Args:
            x: indexable by the keys of ``self.mapping``.

        Returns:
            dict from output-branch index to that copy's output.
        """
        return {
            dst: self.path[str(dst)](x[src])
            for src, dsts in self.mapping.items()
            for dst in dsts
        }

class SupernetLayer(nn.Module):
    """Weighted sum over ``n_ops`` independent copies of one operation.

    Args:
        operation: module deep-copied ``n_ops`` times into ``self.path``.
        n_ops: number of copies.
    """

    def __init__(self, operation, n_ops):
        super().__init__()
        self.path = nn.ModuleList(copy.deepcopy(operation) for _ in range(n_ops))

    def forward(self, x, op_weights):
        """Return ``sum_i op_weights[i] * path[i](x)``.

        Every copy is evaluated, including those with zero weight.
        ``op_weights`` must be indexable for each copy's position.
        """
        total = 0
        for position, branch in enumerate(self.path):
            total = total + op_weights[position] * branch(x)
        return total

class GumbelSoftmax(nn.Module):
    """Draw Gumbel-softmax samples, optionally straight-through one-hot.

    Args:
        dim: dimension over which the categorical is defined. ``None`` selects
            the last dimension (``Tensor.softmax``, ``Tensor.max`` and
            ``scatter_`` all require an explicit integer dimension).
        hard: if True, the forward value is one-hot at the argmax of the soft
            sample while gradients flow through the soft sample
            (straight-through estimator).
    """

    def __init__(self, dim=None, hard=False):
        super().__init__()
        self.hard = hard
        self.dim = dim

    def forward(self, logits, temperature):
        """Sample softmax((logits + g) / temperature) with g ~ Gumbel(0, 1).

        Args:
            logits: unnormalized log-probabilities.
            temperature: divisor applied before the softmax; smaller values
                give sharper samples.

        Returns:
            Tensor with the same shape as ``logits``.
        """
        # known issues with gumbel_softmax for older pytorch versions:
        # https://github.com/pytorch/pytorch/issues/22442
        # https://github.com/pytorch/pytorch/pull/20179
        dim = -1 if self.dim is None else self.dim
        eps = 1e-10
        # -log(E + eps) with E ~ Exp(1) is ~Gumbel(0,1); eps keeps log finite.
        gumbels = -(torch.empty_like(logits).exponential_() + eps).log()
        # ~Gumbel(logits,temperature)
        gumbels = (logits + gumbels) / temperature
        y_soft = gumbels.softmax(dim)

        if self.hard:
            # Straight through: one-hot value, gradient of y_soft.
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
            return y_hard - y_soft.detach() + y_soft
        # Reparameterization trick.
        return y_soft

#-----------------------------------------------------------------------------------------------------------

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    # padding == dilation keeps spatial size unchanged at stride 1.
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    # Bias-free, unpadded pointwise conv.
    return nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
                     kernel_size=1, stride=stride, bias=False)

def conv3x3_fw_v2(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    # Same hyper-parameters as conv3x3, but built from fw.Conv2d_fw_v2.
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return fw.Conv2d_fw_v2(in_planes, out_planes, **conv_kwargs)


class Super_SegNet(nn.Module):
    """SegNet-shaped supernet for searching task-specific branching.

    Every conv stage is a ``SupernetLayer`` holding ``len(tasks)`` copies of
    the stage. In ``forward`` each stage mixes its copies with per-task
    weights: a one-hot identity row while ``warmup_flag`` is set, otherwise a
    Gumbel-softmax sample of that stage's row in ``self.alphas``. Prediction
    heads are stored in ``self.decoder`` keyed by task name.

    Args:
        tasks: ordered task names. ``forward`` looks a task up with
            ``tasks.index`` and indexes ``self.decoder`` by it, so the names
            must be keys of ``self.decoder`` ('semantic', 'depth').
    """

    def __init__(self, tasks):
        super(Super_SegNet, self).__init__()
        self.tasks = tasks
        n_tasks = len(tasks)
        # initialise network parameters
        filters = [64, 128, 256, 512, 512]
        self.class_nb = 7

        # define encoder decoder layers
        self.encoder_block = nn.ModuleList([SupernetLayer(self.conv_layer([3, filters[0]]), n_tasks)])
        self.decoder_block = nn.ModuleList([SupernetLayer(self.conv_layer([filters[0], filters[0]]), n_tasks)])
        for i in range(4):
            self.encoder_block.append(SupernetLayer(self.conv_layer([filters[i], filters[i + 1]]), n_tasks))
            self.decoder_block.append(SupernetLayer(self.conv_layer([filters[i + 1], filters[i]]), n_tasks))

        # define convolution layer
        self.conv_block_enc = nn.ModuleList([SupernetLayer(self.conv_layer([filters[0], filters[0]]), n_tasks)])
        self.conv_block_dec = nn.ModuleList([SupernetLayer(self.conv_layer([filters[0], filters[0]]), n_tasks)])
        for i in range(4):
            if i == 0:
                self.conv_block_enc.append(SupernetLayer(self.conv_layer([filters[i + 1], filters[i + 1]]), n_tasks))
                self.conv_block_dec.append(SupernetLayer(self.conv_layer([filters[i], filters[i]]), n_tasks))
            else:
                self.conv_block_enc.append(SupernetLayer(nn.Sequential(self.conv_layer([filters[i + 1], filters[i + 1]]),
                                                         self.conv_layer([filters[i + 1], filters[i + 1]])), n_tasks))
                self.conv_block_dec.append(SupernetLayer(nn.Sequential(self.conv_layer([filters[i], filters[i]]),
                                                         self.conv_layer([filters[i], filters[i]])), n_tasks))

        # define pooling and unpooling functions
        self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)

        self.decoder = nn.ModuleDict({
            'semantic': self.conv_layer([filters[0], self.class_nb], pred=True),
            'depth': self.conv_layer([filters[0], 1], pred=True)})

        # When True, _get_weights uses a one-hot row per task instead of sampling.
        self.warmup_flag = False
        # Passed as the temperature to gumbel_func; must be set by the caller
        # before a non-warm-up forward pass.
        self.gumbel_temp = None
        self.gumbel_func = GumbelSoftmax(dim=-1, hard=False)

        self._create_alphas()
        self._init_weights()

    @staticmethod
    def _blocks_in_forward_order(net):
        """Return ``net``'s conv stages in the order ``forward`` visits them.

        Entry ``k`` is the stage driven by ``alphas[k]``. First come the
        encoder stages from the shallowest down (``encoder_block[i]`` then
        ``conv_block_enc[i]``). Then come the decoder stages from the deepest
        up (``decoder_block[i]`` then ``conv_block_dec[i]``). ``net`` may be
        any model exposing these four ModuleLists (e.g. ``SegNet``).
        """
        order = []
        for i in range(len(net.encoder_block)):
            order += [net.encoder_block[i], net.conv_block_enc[i]]
        for i in reversed(range(len(net.decoder_block))):
            order += [net.decoder_block[i], net.conv_block_dec[i]]
        return order

    def _create_alphas(self):
        """ Create the architecture parameters of the supergraph.

        One zero-initialised ``(n_tasks, n_tasks)`` logit matrix per stage.
        Row ``t`` holds task ``t``'s logits over that stage's copies. The
        stage count (20 for this layout) is derived from the stage lists, so
        it matches the number of ``_block_forward`` calls in ``forward``.
        """
        n_tasks = len(self.tasks)
        n_blocks = len(self._blocks_in_forward_order(self))
        self.alphas = nn.ParameterList([
            nn.Parameter(torch.zeros(n_tasks, n_tasks)) for _ in range(n_blocks)
        ])

    def _init_weights(self):
        """Xavier-normal init for conv/linear weights, unit/zero init for BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def conv_layer(self, channel, pred=False):
        """Build a stage: conv3x3-BN-ReLU, or (``pred=True``) conv3x3 + conv1x1 head.

        Args:
            channel: ``[in_channels, out_channels]``.
            pred: build a prediction head instead of a feature stage.
        """
        if not pred:
            conv_block = nn.Sequential(
                nn.Conv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=3, padding=1),
                nn.BatchNorm2d(num_features=channel[1]),
                nn.ReLU(inplace=True),
            )
        else:
            conv_block = nn.Sequential(
                nn.Conv2d(in_channels=channel[0], out_channels=channel[0], kernel_size=3, padding=1),
                nn.Conv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=1, padding=0),
            )
        return conv_block

    def _encoder_blocks(self):
        """All stage lists that are shared/branched across tasks."""
        return [self.encoder_block, self.decoder_block, self.conv_block_enc, self.conv_block_dec]

    def _get_weights(self, x, t_idx, idx):
        """Return the mixing weights over a stage's copies for task ``t_idx``.

        Args:
            x: the stage input; only its dtype/device are used in warm-up.
            t_idx: task position in ``self.tasks``.
            idx: stage position (index into ``self.alphas``).
        """
        if self.warmup_flag:
            # One-hot row: during warm-up each task uses only its own copy.
            op_weights = torch.eye(len(self.tasks), dtype=x.dtype,
                                   device=x.device)[t_idx, :]
        else:
            op_weights = self.gumbel_func(
                self.alphas[idx][t_idx, :], temperature=self.gumbel_temp)

        return op_weights

    def _block_forward(self, t_idx, idx, x, block):
        """Run ``block`` for task ``t_idx`` as stage ``idx``; return ``(idx + 1, out)``."""
        op_weights = self._get_weights(x, t_idx, idx)
        out = block(x, op_weights)
        idx += 1
        return idx, out

    def freeze_encoder_bn_running_stats(self):
        """Stop BatchNorm layers in the searchable stages from tracking running stats."""
        for encoder in self._encoder_blocks():
            for m in encoder.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.track_running_stats = False

    def unfreeze_encoder_bn_running_stats(self):
        """Re-enable running-stat tracking in the searchable stages' BatchNorm layers."""
        for encoder in self._encoder_blocks():
            for m in encoder.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.track_running_stats = True

    def reset_encoder_bn_running_stats(self):
        """Reset running stats of BatchNorm layers in the searchable stages."""
        for encoder in self._encoder_blocks():
            for m in encoder.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.reset_running_stats()

    def weight_parameters(self):
        """Yield all network weights (everything except ``alphas``)."""
        for name, param in self.named_weight_parameters():
            yield param

    def named_weight_parameters(self):
        """``(name, param)`` pairs for all parameters not under ``alphas``."""
        return filter(lambda x: not x[0].startswith('alphas'),
                      self.named_parameters())

    def arch_parameters(self):
        """Yield the architecture parameters (``alphas``)."""
        for name, param in self.named_arch_parameters():
            yield param

    def named_arch_parameters(self):
        """``(name, param)`` pairs for parameters under ``alphas``."""
        return filter(lambda x: x[0].startswith('alphas'),
                      self.named_parameters())

    def get_branch_config(self):
        """Return, per stage and task, the index of the most probable copy.

        Returns:
            Nested list ``[n_blocks][n_tasks]`` of floats holding argmax indices.
        """
        n_blocks = len(self.alphas)
        n_tasks = len(self.tasks)
        branch_config = torch.empty(n_blocks, n_tasks, device='cpu')
        for b in range(n_blocks):
            alpha_probs = nn.functional.softmax(
                self.alphas[b], dim=-1).to('cpu').detach()
            for t in range(n_tasks):
                branch_config[b, t] = torch.argmax(alpha_probs[t, :])
        return branch_config.numpy().tolist()

    def calculate_flops_lut(self, file_name, input_size):
        """Measure per-stage FLOPs of a plain ``SegNet`` and save them to ``file_name``.

        Stages are measured in ``alphas`` order (see ``_blocks_in_forward_order``).

        Args:
            file_name: JSON path the LUT is written to.
            input_size: ``(height, width)`` of the dummy input; presumably
                must be divisible by 32 for the five pool/unpool pairs — TODO confirm.

        Returns:
            ``{'per_block_flops': [...]}`` with one entry per stage.
        """
        model = SegNet()
        # SegNet's first conv takes 3 input channels.
        in_shape = (1, 3, input_size[0], input_size[1])
        blocks = self._blocks_in_forward_order(model)
        flops = torch.zeros(len(blocks), device='cpu')

        model.eval()
        with torch.no_grad():
            for idx, m in enumerate(blocks):
                # NOTE(review): assumes add_flops_counting_methods patches `m`
                # in place, so the full forward pass counts only this stage.
                m = resources.add_flops_counting_methods(m)
                m.start_flops_count()
                cache_inputs = torch.rand(in_shape)
                _ = model(cache_inputs)
                block_flops = m.compute_average_flops_cost()
                m.stop_flops_count()
                flops[idx] = block_flops
        flops_dict = {'per_block_flops': flops.numpy().tolist()}
        del model

        # save the FLOPS to LUT
        write_json(flops_dict, file_name)

        return flops_dict

    def get_flops(self):
        """Return per-stage FLOPs, loading the cached LUT or building it first."""
        input_size = [32, 32]

        # Name the LUT after this architecture so a table cached for a
        # different model is never picked up.
        filename = Path('flops_SegNet_{}_{}.json'.format(
            input_size[0], input_size[1]))
        if filename.is_file():
            flops_dict = read_json(filename)
        else:
            print('no LUT found, calculating FLOPS...')
            flops_dict = self.calculate_flops_lut(filename, input_size)
        return flops_dict['per_block_flops']

    def forward(self, x, task):
        """Run the supernet for one task.

        Args:
            x: input image batch (3 channels expected by the first stage).
            task: task name; must be in ``self.tasks`` and ``self.decoder``.

        Returns:
            ``{task: prediction}``.
        """
        g_encoder, g_decoder, g_maxpool, g_upsampl, indices = ([0] * 5 for _ in range(5))
        for i in range(5):
            g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))

        # idx walks self.alphas in the same order as _blocks_in_forward_order.
        idx = 0
        t_idx = self.tasks.index(task)

        # define global shared network
        for i in range(5):
            enc_in = x if i == 0 else g_maxpool[i - 1]
            idx, g_encoder[i][0] = self._block_forward(t_idx, idx, enc_in, self.encoder_block[i])
            idx, g_encoder[i][1] = self._block_forward(t_idx, idx, g_encoder[i][0], self.conv_block_enc[i])
            g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])

        for i in range(5):
            dec_in = g_maxpool[-1] if i == 0 else g_decoder[i - 1][-1]
            g_upsampl[i] = self.up_sampling(dec_in, indices[-i - 1])
            idx, g_decoder[i][0] = self._block_forward(t_idx, idx, g_upsampl[i], self.decoder_block[-i - 1])
            idx, g_decoder[i][1] = self._block_forward(t_idx, idx, g_decoder[i][0], self.conv_block_dec[-i - 1])

        out = {task: self.decoder[task](g_decoder[-1][-1])}

        return out

class SegNet(nn.Module):
    """Plain SegNet with two prediction heads on a shared encoder/decoder.

    Five conv stages encode the input, each followed by 2x2 max-pooling whose
    indices are reused by the matching ``MaxUnpool2d`` in the decoder. Both
    heads read the final decoder feature map: ``pred_task1`` outputs
    ``class_nb`` (7) channels and ``pred_task2`` outputs 1 channel. The input
    must have 3 channels.
    """

    def __init__(self):
        super(SegNet, self).__init__()
        # initialise network parameters
        filters = [64, 128, 256, 512, 512]
        self.class_nb = 7

        # define encoder decoder layers
        self.encoder_block = nn.ModuleList([self.conv_layer([3, filters[0]])])
        self.decoder_block = nn.ModuleList([self.conv_layer([filters[0], filters[0]])])
        for i in range(4):
            self.encoder_block.append(self.conv_layer([filters[i], filters[i + 1]]))
            self.decoder_block.append(self.conv_layer([filters[i + 1], filters[i]]))

        # define convolution layer
        self.conv_block_enc = nn.ModuleList([self.conv_layer([filters[0], filters[0]])])
        self.conv_block_dec = nn.ModuleList([self.conv_layer([filters[0], filters[0]])])
        for i in range(4):
            if i == 0:
                self.conv_block_enc.append(self.conv_layer([filters[i + 1], filters[i + 1]]))
                self.conv_block_dec.append(self.conv_layer([filters[i], filters[i]]))
            else:
                self.conv_block_enc.append(nn.Sequential(self.conv_layer([filters[i + 1], filters[i + 1]]),
                                                         self.conv_layer([filters[i + 1], filters[i + 1]])))
                self.conv_block_dec.append(nn.Sequential(self.conv_layer([filters[i], filters[i]]),
                                                         self.conv_layer([filters[i], filters[i]])))

        self.pred_task1 = self.conv_layer([filters[0], self.class_nb], pred=True)
        self.pred_task2 = self.conv_layer([filters[0], 1], pred=True)

        # define pooling and unpooling functions
        self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def conv_layer(self, channel, pred=False):
        """Build a stage: conv3x3-BN-ReLU, or (``pred=True``) conv3x3 + conv1x1 head.

        Args:
            channel: ``[in_channels, out_channels]``.
            pred: build a prediction head instead of a feature stage.
        """
        if not pred:
            conv_block = nn.Sequential(
                nn.Conv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=3, padding=1),
                nn.BatchNorm2d(num_features=channel[1]),
                nn.ReLU(inplace=True),
            )
        else:
            conv_block = nn.Sequential(
                nn.Conv2d(in_channels=channel[0], out_channels=channel[0], kernel_size=3, padding=1),
                nn.Conv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=1, padding=0),
            )
        return conv_block

    def att_layer(self, channel):
        """Build a 1x1-conv attention block ending in a sigmoid.

        Args:
            channel: ``[in, hidden, out]`` channel counts.
        """
        att_block = nn.Sequential(
            nn.Conv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=1, padding=0),
            nn.BatchNorm2d(channel[1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=channel[1], out_channels=channel[2], kernel_size=1, padding=0),
            nn.BatchNorm2d(channel[2]),
            nn.Sigmoid(),
        )
        return att_block

    def forward(self, x):
        """Return ``[pred_task1(features), pred_task2(features)]`` for input ``x``."""
        g_encoder, g_decoder, g_maxpool, g_upsampl, indices = ([0] * 5 for _ in range(5))
        for i in range(5):
            g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))

        # define global shared network
        for i in range(5):
            enc_in = x if i == 0 else g_maxpool[i - 1]
            g_encoder[i][0] = self.encoder_block[i](enc_in)
            g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
            g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])

        # Decoder mirrors the encoder, unpooling with the stored indices.
        for i in range(5):
            dec_in = g_maxpool[-1] if i == 0 else g_decoder[i - 1][-1]
            g_upsampl[i] = self.up_sampling(dec_in, indices[-i - 1])
            g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
            g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])

        # define task prediction layers
        t1_pred = self.pred_task1(g_decoder[-1][-1])
        t2_pred = self.pred_task2(g_decoder[-1][-1])

        return [t1_pred, t2_pred]

if __name__ == '__main__':
    # Super_SegNet needs its task list; the names must be keys of its
    # `decoder` ModuleDict ('semantic', 'depth').
    a = Super_SegNet(tasks=['semantic', 'depth'])
