import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import resources
import json
import modules_fw_v2 as fw
from collections import OrderedDict
from pathlib import Path

def read_json(fname):
    """Load a JSON file, preserving the key order of every object.

    Args:
        fname: path (``str`` or ``Path``) of the JSON file to read.

    Returns:
        The decoded document; every JSON object is decoded as an
        ``OrderedDict``.
    """
    fname = Path(fname)
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # locale encoding (which is not UTF-8 on e.g. many Windows setups).
    with fname.open('rt', encoding='utf-8') as handle:
        return json.load(handle, object_hook=OrderedDict)


def write_json(content, fname):
    """Serialize ``content`` to ``fname`` as 4-space-indented JSON.

    Keys are written in their existing (insertion) order rather than sorted.
    """
    target = Path(fname)
    with target.open('wt') as out_file:
        json.dump(content, out_file, sort_keys=False, indent=4)

class BranchedLayer(nn.Module):
    """Fan input branches out to independent copies of one layer.

    ``mapping`` maps an input-branch index to the list of output-branch
    indices fed from it. Each output branch owns its own deep copy of
    ``operation``, stored in ``self.path`` under the stringified output
    index, so the copies start identical but do not share parameters.
    """

    def __init__(self, operation, mapping):
        super().__init__()
        self.mapping = mapping
        self.path = nn.ModuleDict()
        for out_indices in self.mapping.values():
            for out_idx in out_indices:
                # ModuleDict keys must be strings.
                self.path[str(out_idx)] = copy.deepcopy(operation)

    def forward(self, x):
        """Apply every output branch's layer to its parent input branch.

        Args:
            x: dict keyed by input-branch index (the keys of ``mapping``).

        Returns:
            dict keyed by output-branch index.
        """
        branch_outputs = {}
        for in_idx, out_indices in self.mapping.items():
            for out_idx in out_indices:
                layer = self.path[str(out_idx)]
                branch_outputs[out_idx] = layer(x[in_idx])
        return branch_outputs

class SupernetLayer(nn.Module):
    """Weighted mixture of ``n_ops`` independent copies of one layer.

    ``forward`` returns ``sum_i op_weights[i] * path[i](x)``.
    """

    def __init__(self, operation, n_ops):
        super().__init__()
        self.path = nn.ModuleList(copy.deepcopy(operation) for _ in range(n_ops))

    def forward(self, x, op_weights):
        mixed = 0
        for i, candidate in enumerate(self.path):
            mixed = mixed + op_weights[i] * candidate(x)
        return mixed

class GumbelSoftmax(nn.Module):
    """Draw a (relaxed) categorical sample with the Gumbel-softmax trick.

    Args:
        dim: dimension of ``logits`` holding the categories.
        hard: if True, the forward value is one-hot at the argmax while the
            gradient is that of the soft sample (straight-through estimator).
    """

    def __init__(self, dim=None, hard=False):
        super().__init__()
        self.hard = hard
        self.dim = dim

    def forward(self, logits, temperature):
        # Hand-rolled rather than F.gumbel_softmax because of known issues in
        # older PyTorch versions:
        # https://github.com/pytorch/pytorch/issues/22442
        # https://github.com/pytorch/pytorch/pull/20179
        eps = 1e-10
        # -log(Exp(1) + eps) ~ Gumbel(0, 1); eps keeps the log finite.
        exp_noise = torch.empty_like(logits).exponential_()
        gumbel_noise = -torch.log(exp_noise + eps)
        # Perturbed, temperature-scaled logits ~ Gumbel(logits, temperature).
        perturbed = (logits + gumbel_noise) / temperature
        y_soft = perturbed.softmax(self.dim)

        if not self.hard:
            # Reparameterised soft sample.
            return y_soft

        # Straight-through: one-hot forward value, soft gradient.
        _, index = y_soft.max(self.dim, keepdim=True)
        y_hard = torch.zeros_like(logits).scatter_(self.dim, index, 1.0)
        return y_hard - y_soft.detach() + y_soft

#-----------------------------------------------------------------------------------------------------------

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d``.

    Padding equals ``dilation``, so spatial size is preserved at stride 1.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )

def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 (pointwise) ``nn.Conv2d``."""
    return nn.Conv2d(in_planes, out_planes, 1, stride=stride, bias=False)

def conv3x3_fw_v2(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """``fw.Conv2d_fw_v2`` counterpart of ``conv3x3``: 3x3, no bias, padding = dilation."""
    return fw.Conv2d_fw_v2(
        in_planes, out_planes,
        kernel_size=3, stride=stride, padding=dilation,
        groups=groups, bias=False, dilation=dilation,
    )

class Fist_conv_block(nn.Module):
    """ResNet stem: 7x7 stride-2 conv -> BatchNorm -> ReLU -> 3x3 stride-2 max-pool.

    Args:
        inplanes: number of channels produced by the stem.
        in_channels: number of channels of the input image (default 1,
            i.e. single-channel input as before).
    """

    def __init__(self, inplanes, in_channels=1):
        super(Fist_conv_block, self).__init__()
        # The conv must emit ``inplanes`` channels to match the BatchNorm
        # that follows (previously hard-coded to 64, which broke any other
        # ``inplanes`` value).
        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels, inplanes, kernel_size=(7, 7), stride=(2, 2),
                      padding=(3, 3), bias=False),
            nn.BatchNorm2d(inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

    def forward(self, x):
        """Apply the stem to ``x`` of shape (N, in_channels, H, W)."""
        return self.first_conv(x)

class BasicBlock(nn.Module):
    """Two-convolution residual block (ResNet-18/34 style).

    conv3x3 -> norm -> ReLU -> conv3x3 -> norm, plus the (optionally
    downsampled) input, then a final ReLU.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # With stride != 1, conv1 reduces resolution; the caller-supplied
        # ``downsample`` is expected to do the same on the shortcut.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)

class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block (ResNet-50+ style).

    As in torchvision, the downsampling stride sits on the 3x3 convolution
    (``conv2``) rather than on the first 1x1 convolution as in the original
    "Deep residual learning for image recognition"
    (https://arxiv.org/abs/1512.03385). This variant is known as ResNet V1.5
    and is reported to improve accuracy:
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion
        # With stride != 1, conv2 reduces resolution; the caller-supplied
        # ``downsample`` is expected to do the same on the shortcut.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)

class Generate_ResNet():
    """Build a ResNet as a flat Python list of blocks (not an ``nn.Module``).

    The list is: the stem (``Fist_conv_block``), every residual block of the
    four stages in order, and finally the ``nn.Linear`` classifier. No
    pooling/flatten module sits between the last stage and the classifier;
    callers insert that themselves. Arguments mirror torchvision's ``ResNet``.
    """

    def __init__(self, block, layers, num_classes=1000,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(Generate_ResNet, self).__init__()
        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer

        self.inplanes = 64
        self.dilation = 1
        # One flag per stage 2..4: replace that stage's stride-2 downsampling
        # with a dilated convolution.
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        blocks = [Fist_conv_block(self.inplanes)]
        blocks.extend(self._make_layer(block, 64, layers[0]))
        for stage_idx, planes in enumerate((128, 256, 512), start=1):
            blocks.extend(self._make_layer(
                block, planes, layers[stage_idx], stride=2,
                dilate=replace_stride_with_dilation[stage_idx - 1]))
        blocks.append(nn.Linear(512 * block.expansion, num_classes))
        self.blocks = blocks

    def get_blocks(self):
        """Return the list of blocks built in ``__init__``."""
        return self.blocks

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Return one stage as a list of ``blocks`` residual blocks.

        Only the first block may change resolution/width; it receives a
        1x1-conv + norm shortcut when its stride or channel count differs
        from the incoming one. Updates ``self.inplanes`` (and
        ``self.dilation`` when ``dilate``) for the following stage.
        """
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample, self.groups,
                       self.base_width, previous_dilation, norm_layer)]
        self.inplanes = out_planes
        stage.extend(
            block(self.inplanes, planes, groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=norm_layer)
            for _ in range(1, blocks)
        )
        return stage

def resnet18(num_classes=1000):
    """Return ResNet-18 (BasicBlock, stages [2, 2, 2, 2]) as a flat list of blocks."""
    builder = Generate_ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    return builder.get_blocks()

class SuperMnistResNet18(torch.nn.Module):
    """Branching-search supernet over a single-channel ResNet-18 backbone.

    Every backbone block (stem, eight BasicBlocks and the final
    ``Linear(512, 100)``) is wrapped in a ``SupernetLayer`` holding one copy
    per task. For task ``t``, block ``b`` outputs a weighted sum of all its
    copies. The weights are a Gumbel-softmax sample of the architecture
    parameters ``alphas[b][t, :]``. While ``warmup_flag`` is set, they are
    instead a fixed one-hot selecting copy ``t``. Each task has its own
    ``Linear(100, 10)`` head.

    Args:
        tasks: list of task names; a name's position is its task index.
    """

    def __init__(self, tasks):
        super(SuperMnistResNet18, self).__init__()
        self.tasks = tasks
        blocks = resnet18(num_classes=100)
        self.n_blocks = len(blocks)

        # One SupernetLayer per backbone block, each with len(tasks) copies.
        self.encoder = nn.ModuleList(SupernetLayer(
            bl, len(self.tasks)) for bl in blocks)

        # Applied right before the last encoder block (the Linear classifier).
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.decoder = nn.ModuleDict({
            task: nn.Linear(100, 10)
            for task in self.tasks})

        # While True, task t always routes through copy t of every block.
        self.warmup_flag = False
        # Must be set to a number by the caller before a non-warmup forward.
        self.gumbel_temp = None
        self.gumbel_func = GumbelSoftmax(dim=-1, hard=False)

        self._create_alphas()
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init every Conv2d; set norm layers to weight 1, bias 0.

        Linear layers keep PyTorch's default initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # NOTE: zero-initializing the last BN of each residual branch
        # (https://arxiv.org/abs/1706.02677) is intentionally not applied.

    def freeze_encoder_bn_running_stats(self):
        """Stop encoder BatchNorm layers from tracking running statistics."""
        for m in self.encoder.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.track_running_stats = False

    def unfreeze_encoder_bn_running_stats(self):
        """Resume running-statistics tracking in encoder BatchNorm layers."""
        for m in self.encoder.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.track_running_stats = True

    def reset_encoder_bn_running_stats(self):
        """Reset running mean/var of every encoder BatchNorm layer."""
        for m in self.encoder.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.reset_running_stats()

    def weight_parameters(self):
        """Yield all network weights (everything except the ``alphas``)."""
        for name, param in self.named_weight_parameters():
            yield param

    def named_weight_parameters(self):
        """Return ``(name, param)`` pairs whose name does not start with 'alphas'."""
        return filter(lambda x: not x[0].startswith('alphas'),
                      self.named_parameters())

    def arch_parameters(self):
        """Yield the architecture parameters (the ``alphas``)."""
        for name, param in self.named_arch_parameters():
            yield param

    def named_arch_parameters(self):
        """Return ``(name, param)`` pairs whose name starts with 'alphas'."""
        return filter(lambda x: x[0].startswith('alphas'),
                      self.named_parameters())

    def _create_alphas(self):
        """Create one zero-initialized (n_tasks, n_tasks) alpha matrix per encoder block."""
        n_tasks = len(self.tasks)
        self.alphas = nn.ParameterList([
            nn.Parameter(torch.zeros(n_tasks, n_tasks)) for _ in self.encoder
        ])

    def shared_modules(self):
        """Return the modules shared across tasks (the backbone encoder)."""
        # Previously returned ``self.feature_extractor``, which this class
        # never defines (AttributeError).
        return [self.encoder]

    def zero_grad_shared_modules(self):
        """Zero the gradients of every shared module."""
        for mm in self.shared_modules():
            mm.zero_grad()

    def forward(self, x, task):
        """Run the supernet for one task.

        Args:
            x: input batch for the single-channel stem.
            task: one of ``self.tasks``.

        Returns:
            ``{task: logits}`` from that task's decoder head.
        """
        t_idx = self.tasks.index(task)
        n_tasks = len(self.tasks)

        for idx, op in enumerate(self.encoder):
            if self.warmup_flag:
                # One-hot on this task's own copy, built with x's dtype/device
                # directly instead of resizing a non-empty tensor via ``out=``.
                op_weights = torch.eye(n_tasks, dtype=x.dtype, device=x.device)[t_idx, :]
            else:
                op_weights = self.gumbel_func(
                    self.alphas[idx][t_idx, :], temperature=self.gumbel_temp)

            if idx == self.n_blocks - 1:
                # The last block is the Linear classifier: pool and flatten first.
                x = self.avg_pool(x)
                x = torch.flatten(x, 1)

            x = op(x, op_weights)

        x = F.relu(x)
        out = {task: self.decoder[task](x)}
        return out

    def get_branch_config(self):
        """Return, per block and task, the index of the highest-alpha copy.

        Returns:
            Nested list of shape (n_blocks, n_tasks) holding float indices.
        """
        n_blocks = len(self.alphas)
        n_tasks = len(self.tasks)
        branch_config = torch.empty(n_blocks, n_tasks, device='cpu')
        for b in range(n_blocks):
            alpha_probs = nn.functional.softmax(
                self.alphas[b], dim=-1).to('cpu').detach()
            for t in range(n_tasks):
                branch_config[b, t] = torch.argmax(alpha_probs[t, :])
        return branch_config.numpy().tolist()

    def calculate_flops_lut(self, file_name, input_size):
        """Measure per-block FLOPs of a single-branch copy of the backbone.

        Builds a one-task ``BranchMnistResNet18`` (every block exists once),
        runs it on a random ``(1, 1, H, W)`` input and records the FLOPs
        counted for each encoder block. The result is saved to ``file_name``.

        Args:
            file_name: path of the JSON lookup table to write.
            input_size: ``(H, W)`` of the dummy input.

        Returns:
            ``{'per_block_flops': [...]}`` with one entry per encoder block.
        """
        # Single task => every block maps branch 0 to branch 0. Sized from the
        # actual backbone instead of a hard-coded block count.
        shared_config = torch.zeros(self.n_blocks, 1)
        model = BranchMnistResNet18(tasks=['t1'],
                                    branch_config=shared_config)
        in_shape = (1, 1, input_size[0], input_size[1])
        n_blocks = len(model.encoder)
        flops = torch.zeros(n_blocks, device='cpu')

        model.eval()
        with torch.no_grad():
            for idx, m in enumerate(model.encoder):
                m = resources.add_flops_counting_methods(m)
                m.start_flops_count()
                cache_inputs = torch.rand(in_shape)
                _ = model(cache_inputs)
                block_flops = m.compute_average_flops_cost()
                m.stop_flops_count()
                flops[idx] = block_flops
        flops_dict = {'per_block_flops': flops.numpy().tolist()}
        del model

        # save the FLOPS to LUT
        write_json(flops_dict, file_name)

        return flops_dict

    def get_flops(self):
        """Return per-block FLOPs, loading a cached LUT or computing it."""
        input_size = [32, 32]  # hard-coded input resolution of the LUT

        # The cache name must identify this architecture; a name copied from
        # another model (previously 'flops_MobileNetV2_...') would let that
        # model's LUT be silently loaded here.
        filename = Path('flops_MnistResNet18_{}_{}.json'.format(
            input_size[0], input_size[1]))
        if filename.is_file():
            flops_dict = read_json(filename)
        else:
            print('no LUT found, calculating FLOPS...')
            flops_dict = self.calculate_flops_lut(filename, input_size)
        return flops_dict['per_block_flops']

class BranchMnistResNet18_v2(torch.nn.Module):
    """Variant of the branched multi-task ResNet-18 with extra options.

    Builds the branched network from ``branch_config`` exactly like
    ``BranchMnistResNet18``. ``branched`` and ``topK`` are stored but not yet
    used by this class. ``_shared_parameters`` and
    ``_task_specific_parameters`` start empty.

    Args:
        tasks: list of task names; a name's position is its task index.
        branch_config: per-block, per-task branch ids (one row per backbone
            block, one column per task).
        branched: stored as ``self.branched``; currently unused.
        topK: stored as ``self.topK``; currently unused.
    """

    def __init__(self, tasks, branch_config, branched='empty', topK=None):
        super(BranchMnistResNet18_v2, self).__init__()
        self.tasks = tasks
        self.n_tasks = len(tasks)
        self.branched = branched
        self.topK = topK

        self._shared_parameters = OrderedDict()
        self._task_specific_parameters = OrderedDict()

        self.branch_config = branch_config
        # Also sets self.x_decoder_mapping, which forward() relies on.
        mappings = self._get_branch_mappings()

        blocks = resnet18(num_classes=100)
        self.n_blocks = len(blocks)

        self.encoder = nn.Sequential(*[BranchedLayer(
            bl, ma) for bl, ma in zip(blocks, mappings)])

        self.decoder = nn.ModuleDict({
            task: nn.Linear(100, 10)
            for task in self.tasks})

        self.avg_pool = fw.AdaptiveAvgPool2d_fw_v2((1, 1))

        self._init_weights()

    def _get_branch_mappings(self):
        """
        Calculates branch mappings from the branch_config. `mappings` is a list of dicts mapping
        the index of an input branch to indices of output branches. For example:
        mappings = [
            {0: [0]},
            {0: [0, 1]},
            {0: [0, 2], 1: [1, 3]},
            {0: [1], 1: [2], 2: [3], 3: [0]}
        ]

        Also sets ``self.x_decoder_mapping``: task name -> index of the final
        branch that carries that task.
        """

        def get_partition(layer_config):
            """Group task indices by the branch id they chose in this layer."""
            partition = []
            for t in range(len(self.tasks)):
                s = {i for i, x in enumerate(layer_config) if x == t}
                if not len(s) == 0:
                    partition.append(s)
            return partition

        def make_refinement(partition, ancestor):
            """ make `partition` a refinement of `ancestor` """
            refinement = []
            for part_1 in partition:
                for part_2 in ancestor:
                    inter = part_1.intersection(part_2)
                    if not len(inter) == 0:
                        refinement.append(inter)
            return refinement

        task_grouping = [set(range(len(self.tasks))), ]

        mappings = []
        for layer_config in self.branch_config:

            partition = get_partition(layer_config)
            # Groups can only split further: refine w.r.t. the previous layer.
            refinement = make_refinement(partition, task_grouping)

            # Previously this returned the raw lists of sets, which
            # BranchedLayer (it calls ``mapping.values()``) cannot consume,
            # and never advanced ``task_grouping`` between layers.
            out_dict = {}
            for prev_idx, prev in enumerate(task_grouping):
                out_dict[prev_idx] = []
                for curr_idx, curr in enumerate(refinement):
                    if curr.issubset(prev):
                        out_dict[prev_idx].append(curr_idx)

            task_grouping = refinement
            mappings.append(out_dict)

        self.x_decoder_mapping = {
            task: [t_idx in gr for gr in task_grouping].index(True)
            for t_idx, task in enumerate(self.tasks)
        }

        return mappings

    def _init_weights(self):
        """Kaiming-init every Conv2d; set norm layers to weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # NOTE: zero-initializing the last BN of each residual branch
        # (https://arxiv.org/abs/1706.02677) is intentionally not applied.

    def _create_alphas(self):
        """Create one zero-initialized (n_tasks, n_tasks) alpha matrix per encoder block."""
        n_tasks = len(self.tasks)
        # Previously iterated ``self.feature_extractor.encoder``, which does not exist.
        self.alphas = nn.ParameterList([
            nn.Parameter(torch.zeros(n_tasks, n_tasks)) for _ in self.encoder
        ])

    def shared_modules(self):
        """Return the backbone encoder (previously a nonexistent attribute)."""
        return [self.encoder]

    def zero_grad_shared_modules(self):
        """Zero the gradients of every shared module."""
        for mm in self.shared_modules():
            mm.zero_grad()

    def forward(self, x):
        """Run all task branches on ``x``; return ``{task: logits}``.

        NOTE(review): assumes ``fw.AdaptiveAvgPool2d_fw_v2`` and
        ``fw.flatten_fw_v2`` accept and return dicts of branch tensors —
        confirm against ``modules_fw_v2``.
        """
        out = {0: x}
        for i, layer in enumerate(self.encoder):
            if i == self.n_blocks - 1:
                out = self.avg_pool(out)
                out = fw.flatten_fw_v2(out, 1)

            out = layer(out)

        output = {task: self.decoder[task](out[self.x_decoder_mapping[task]]) for task in self.tasks}

        return output

class BranchMnistResNet18(torch.nn.Module):
    """Fixed branched multi-task ResNet-18 built from a branch configuration.

    ``branch_config[b][t]`` is the branch id of task ``t`` at block ``b``
    (e.g. as returned by ``SuperMnistResNet18.get_branch_config``). Two tasks
    share block ``b``'s weights only if they chose the same id there and
    already shared every earlier block. Each block is a ``BranchedLayer``,
    and every task has its own ``Linear(100, 10)`` head.

    Args:
        tasks: list of task names; a name's position is its task index.
        branch_config: per-block, per-task branch ids (nested list or 2-D
            tensor; one row per backbone block, one column per task).
    """

    def __init__(self, tasks, branch_config):
        super(BranchMnistResNet18, self).__init__()
        self.tasks = tasks
        self.branch_config = branch_config

        # Also sets self.x_decoder_mapping, which forward() relies on.
        mappings = self._get_branch_mappings()

        blocks = resnet18(num_classes=100)
        self.n_blocks = len(blocks)

        self.encoder = nn.Sequential(*[BranchedLayer(
            bl, ma) for bl, ma in zip(blocks, mappings)])

        self.decoder = nn.ModuleDict({
            task: nn.Linear(100, 10)
            for task in self.tasks})

        self.avg_pool = fw.AdaptiveAvgPool2d_fw_v2((1, 1))

        self._init_weights()

    def _get_branch_mappings(self):
        """
        Calculates branch mappings from the branch_config. `mappings` is a list of dicts mapping
        the index of an input branch to indices of output branches. For example:
        mappings = [
            {0: [0]},
            {0: [0, 1]},
            {0: [0, 2], 1: [1, 3]},
            {0: [1], 1: [2], 2: [3], 3: [0]}
        ]

        Also sets ``self.x_decoder_mapping``: task name -> index of the final
        branch that carries that task.
        """

        def get_partition(layer_config):
            """Group task indices by the branch id they chose in this layer."""
            partition = []
            for t in range(len(self.tasks)):
                s = {i for i, x in enumerate(layer_config) if x == t}
                if not len(s) == 0:
                    partition.append(s)
            return partition

        def make_refinement(partition, ancestor):
            """ make `partition` a refinement of `ancestor` """
            refinement = []
            for part_1 in partition:
                for part_2 in ancestor:
                    inter = part_1.intersection(part_2)
                    if not len(inter) == 0:
                        refinement.append(inter)
            return refinement

        task_grouping = [set(range(len(self.tasks))), ]

        mappings = []
        for layer_idx, layer_config in enumerate(self.branch_config):

            partition = get_partition(layer_config)
            # Groups can only split further: refine w.r.t. the previous layer.
            refinement = make_refinement(partition, task_grouping)

            out_dict = {}
            for prev_idx, prev in enumerate(task_grouping):
                out_dict[prev_idx] = []
                for curr_idx, curr in enumerate(refinement):
                    if curr.issubset(prev):
                        out_dict[prev_idx].append(curr_idx)

            task_grouping = refinement
            mappings.append(out_dict)

        self.x_decoder_mapping = {
            task: [t_idx in gr for gr in task_grouping].index(True)
            for t_idx, task in enumerate(self.tasks)
        }

        return mappings

    def _init_weights(self):
        """Kaiming-init every Conv2d; set norm layers to weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # NOTE: zero-initializing the last BN of each residual branch
        # (https://arxiv.org/abs/1706.02677) is intentionally not applied.

    def _create_alphas(self):
        """Create one zero-initialized (n_tasks, n_tasks) alpha matrix per encoder block."""
        n_tasks = len(self.tasks)
        # Previously iterated ``self.feature_extractor.encoder``, which does not exist.
        self.alphas = nn.ParameterList([
            nn.Parameter(torch.zeros(n_tasks, n_tasks)) for _ in self.encoder
        ])

    def shared_modules(self):
        """Return the backbone encoder (previously a nonexistent attribute)."""
        return [self.encoder]

    def zero_grad_shared_modules(self):
        """Zero the gradients of every shared module."""
        for mm in self.shared_modules():
            mm.zero_grad()

    def forward(self, x):
        """Run all task branches on ``x``; return ``{task: logits}``.

        NOTE(review): assumes ``fw.AdaptiveAvgPool2d_fw_v2`` and
        ``fw.flatten_fw_v2`` accept and return dicts of branch tensors —
        confirm against ``modules_fw_v2``.
        """
        out = {0: x}
        for i, layer in enumerate(self.encoder):
            if i == self.n_blocks - 1:
                out = self.avg_pool(out)
                out = fw.flatten_fw_v2(out, 1)

            out = layer(out)

        output = {task: self.decoder[task](out[self.x_decoder_mapping[task]]) for task in self.tasks}

        return output

if __name__ == '__main__':
    # Smoke check: build a two-task supernet.
    supernet = SuperMnistResNet18(tasks=['t1', 't2'])
